# Sharp-It / zero123plus / pipeline.py
# Last updated by YiftachEde (commit d6502a4).
from typing import Any, Dict, Optional
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
import numpy
import torch
import torch.nn as nn
import torch.utils.checkpoint
import torch.distributed
import transformers
from collections import OrderedDict
from PIL import Image
from torchvision import transforms
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
randn_tensor = torch.randn
import diffusers
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline,
EulerAncestralDiscreteScheduler,
UNet2DConditionModel,
ImagePipelineOutput,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
Attention,
AttnProcessor,
XFormersAttnProcessor,
AttnProcessor2_0,
)
from diffusers.utils.import_utils import is_xformers_available
import spaces
def extract_into_tensor(a, t, x_shape):
    """Gather per-sample values from ``a`` and shape them for broadcasting.

    ``a.gather(-1, t)`` picks the entries of ``a`` indexed by ``t``. The result is
    reshaped to ``(batch, 1, ..., 1)``, where ``batch`` is the first dimension of
    ``t`` and there are ``len(x_shape) - 1`` trailing singleton dimensions, so it
    broadcasts against a tensor of shape ``x_shape``.
    """
    batch, *_unused = t.shape
    trailing_ones = (1,) * (len(x_shape) - 1)
    gathered = a.gather(-1, t)
    return gathered.reshape(batch, *trailing_ones)
def to_rgb_image(maybe_rgba: Image.Image):
    """Return an RGB version of ``maybe_rgba``.

    RGB images are returned unchanged (the very same object). RGBA images are
    alpha-composited onto an opaque white canvas of the same size. Any other
    mode raises ``ValueError``.
    """
    mode = maybe_rgba.mode
    if mode == "RGB":
        return maybe_rgba
    if mode != "RGBA":
        raise ValueError("Unsupported image type.", maybe_rgba.mode)
    # Solid white background; the alpha channel decides how much of the
    # foreground covers it.
    canvas = Image.new("RGB", maybe_rgba.size, (255, 255, 255))
    canvas.paste(maybe_rgba, mask=maybe_rgba.getchannel("A"))
    return canvas
class ReferenceOnlyAttnProc(torch.nn.Module):
    """Attention-processor wrapper that records or reuses a reference context.

    When ``enabled`` is true, each call either stores the attention context
    (``encoder_hidden_states``) in ``ref_dict`` under ``name``, or appends a
    stored context to the current one along dim 1, before delegating to
    ``chained_proc``. When disabled the call is a pass-through to
    ``chained_proc`` (with ``encoder_hidden_states`` defaulted to
    ``hidden_states``).

    Args:
        chained_proc: underlying processor, called as
            ``chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)``.
        enabled: whether reference handling is active for this layer.
        name: key used for this layer in ``ref_dict``.
    """

    def __init__(self, chained_proc, enabled=False, name=None) -> None:
        super().__init__()
        self.enabled = enabled
        self.chained_proc = chained_proc
        self.name = name

    def __call__(
        self,
        attn: "Attention",
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        mode="w",
        ref_dict: dict = None,
        is_cfg_guidance=False,
    ) -> Any:
        """Run attention, reading or writing the reference context per ``mode``.

        Modes (only consulted when ``self.enabled``):
            ``"w"``: store the context as ``ref_dict[self.name]``.
            ``"r"``: append ``ref_dict.pop(self.name)`` along dim 1 (consumes it).
            ``"m"``: append ``ref_dict[self.name]`` along dim 1 (keeps it).
            ``"c"``: append the context to itself along dim 1.

        With ``is_cfg_guidance`` (and ``enabled``), batch element 0 is processed
        without any reference handling, the remaining elements with it, and the
        two results are concatenated back along dim 0.

        Raises:
            ValueError: if ``self.enabled`` and ``mode`` is not one of the above.
        """
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        if self.enabled and is_cfg_guidance:
            # First batch element bypasses the reference mechanism entirely.
            res0 = self.chained_proc(
                attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask
            )
            hidden_states = hidden_states[1:]
            encoder_hidden_states = encoder_hidden_states[1:]
        if self.enabled:
            if mode == "w":
                ref_dict[self.name] = encoder_hidden_states
            elif mode == "r":
                encoder_hidden_states = torch.cat(
                    [encoder_hidden_states, ref_dict.pop(self.name)], dim=1
                )
            elif mode == "m":
                encoder_hidden_states = torch.cat(
                    [encoder_hidden_states, ref_dict[self.name]], dim=1
                )
            elif mode == "c":
                encoder_hidden_states = torch.cat(
                    [encoder_hidden_states, encoder_hidden_states], dim=1
                )
            else:
                # Explicit raise: an ``assert`` is stripped under ``python -O``,
                # which would let unknown modes silently skip reference handling.
                raise ValueError(f"Unknown reference-attention mode: {mode!r}")
        res = self.chained_proc(
            attn, hidden_states, encoder_hidden_states, attention_mask
        )
        if self.enabled and is_cfg_guidance:
            res = torch.cat([res0, res])
        return res
class RefOnlyNoisedUNet(torch.nn.Module):
    """UNet wrapper that injects a noised condition latent via reference attention.

    Every attention processor of ``unet`` is replaced by a
    ``ReferenceOnlyAttnProc``; reference handling is enabled only for processors
    whose name ends with ``"attn1.processor"``. ``forward`` noises
    ``cross_attention_kwargs["cond_lat"]`` to the current timestep, optionally
    runs a conditioning pass that records reference contexts (mode ``"w"``),
    then runs the main pass that consumes them (mode ``"r"``). When
    ``"dont_forward_cond_state"`` is present in ``cross_attention_kwargs`` the
    conditioning pass is skipped and each enabled layer's context is
    concatenated with itself instead (mode ``"c"``).
    """

    def __init__(
        self,
        unet: UNet2DConditionModel,
        train_sched: DDPMScheduler,
        val_sched: EulerAncestralDiscreteScheduler,
    ) -> None:
        super().__init__()
        self.unet = unet
        # Noises the condition latent while ``self.training`` is true.
        self.train_sched = train_sched
        # Noises the condition latent in eval mode.
        self.val_sched = val_sched
        unet_lora_attn_procs = dict()
        for name, _ in unet.attn_processors.items():
            # Pick the attention implementation wrapped by each processor.
            if torch.__version__ >= "2.0":
                default_attn_proc = AttnProcessor2_0()
            elif is_xformers_available():
                default_attn_proc = XFormersAttnProcessor()
            else:
                default_attn_proc = AttnProcessor()
            unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
                default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
            )
        unet.set_attn_processor(unet_lora_attn_procs)

    def __getattr__(self, name: str):
        # Fall back to the wrapped UNet (e.g. ``config``, ``dtype``) for attributes
        # this wrapper does not define.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward_cond(
        self,
        noisy_cond_lat,
        timestep,
        encoder_hidden_states,
        class_labels,
        ref_dict,
        is_cfg_guidance,
        **kwargs,
    ):
        """Run the conditioning pass that fills ``ref_dict`` (mode ``"w"``).

        The UNet output is discarded; the pass runs for its side effect on
        ``ref_dict`` through the reference-attention processors (assuming the UNet
        forwards ``cross_attention_kwargs`` to them). With ``is_cfg_guidance`` the
        first entry of ``encoder_hidden_states`` and ``class_labels`` is dropped, so
        ``class_labels`` must not be ``None`` in that case.
        """
        if is_cfg_guidance:
            encoder_hidden_states = encoder_hidden_states[1:]
            class_labels = class_labels[1:]
        self.unet(
            noisy_cond_lat,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
            **kwargs,
        )

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states,
        class_labels=None,
        *args,
        cross_attention_kwargs,
        down_block_res_samples=None,
        mid_block_res_sample=None,
        forward_cond_state=True,
        **kwargs,
    ):
        """Noise the condition latent, run the conditioning pass, then the main pass.

        Args:
            cross_attention_kwargs: must contain ``"cond_lat"``; may contain
                ``"is_cfg_guidance"`` (default False) and
                ``"dont_forward_cond_state"`` (its presence skips the
                conditioning pass). No other entries reach the UNet.
            down_block_res_samples / mid_block_res_sample: optional residuals
                (e.g. from a ControlNet), cast to the UNet dtype and passed as
                ``down_block_additional_residuals`` /
                ``mid_block_additional_residual``.
            forward_cond_state: accepted but unused.

        In eval mode ``timestep`` must be a tensor (it is ``reshape``-d).
        """
        cond_lat = cross_attention_kwargs["cond_lat"]
        is_cfg_guidance = cross_attention_kwargs.get("is_cfg_guidance", False)
        # Drawn from torch's global RNG; no pipeline generator is used here.
        noise = torch.randn_like(cond_lat)
        if self.training:
            noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
            noisy_cond_lat = self.train_sched.scale_model_input(
                noisy_cond_lat, timestep
            )
        else:
            noisy_cond_lat = self.val_sched.add_noise(
                cond_lat, noise, timestep.reshape(-1)
            )
            noisy_cond_lat = self.val_sched.scale_model_input(
                noisy_cond_lat, timestep.reshape(-1)
            )
        # Shared between the conditioning pass (writer) and the main pass (reader).
        ref_dict = {}
        if "dont_forward_cond_state" not in cross_attention_kwargs.keys():
            self.forward_cond(
                noisy_cond_lat,
                timestep,
                encoder_hidden_states,
                class_labels,
                ref_dict,
                is_cfg_guidance,
                **kwargs,
            )
            mode = "r"
        else:
            # No recorded contexts: enabled layers attend to their own context twice.
            mode = "c"
        weight_dtype = self.unet.dtype
        return self.unet(
            sample,
            timestep,
            encoder_hidden_states,
            *args,
            class_labels=class_labels,
            cross_attention_kwargs=dict(
                mode=mode, ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance
            ),
            # The comprehension's ``sample`` is local to it; the outer ``sample``
            # passed above is unaffected.
            down_block_additional_residuals=[
                sample.to(dtype=weight_dtype) for sample in down_block_res_samples
            ]
            if down_block_res_samples is not None
            else None,
            mid_block_additional_residual=(
                mid_block_res_sample.to(dtype=weight_dtype)
                if mid_block_res_sample is not None
                else None
            ),
            **kwargs,
        )
def scale_latents(latents):
    """Shift latents by -0.22, then multiply by 0.75 (inverse of ``unscale_latents``)."""
    return (latents - 0.22) * 0.75
def unscale_latents(latents):
    """Divide latents by 0.75, then add 0.22 (inverse of ``scale_latents``)."""
    return latents / 0.75 + 0.22
def scale_image(image):
    """Multiply by 0.5 and then divide by 0.8 (inverse of ``unscale_image``)."""
    return image * 0.5 / 0.8
def unscale_image(image):
    """Divide by 0.5 and then multiply by 0.8 (inverse of ``scale_image``)."""
    return image / 0.5 * 0.8
class DepthControlUNet(torch.nn.Module):
    """Feed ControlNet residuals computed from a depth map into the wrapped UNet.

    Args:
        unet: the reference-attention UNet wrapper; its ``forward`` accepts
            ``down_block_res_samples`` and ``mid_block_res_sample``.
        controlnet: ControlNet to use; when ``None`` one is created with
            ``diffusers.ControlNetModel.from_unet(unet.unet)``.
        conditioning_scale: forwarded to the ControlNet as ``conditioning_scale``.
    """

    def __init__(
        self,
        unet: "RefOnlyNoisedUNet",
        controlnet: "Optional[diffusers.ControlNetModel]" = None,
        conditioning_scale=1.0,
    ) -> None:
        super().__init__()
        self.unet = unet
        if controlnet is None:
            self.controlnet = diffusers.ControlNetModel.from_unet(unet.unet)
        else:
            self.controlnet = controlnet
        # Use xFormers attention for the ControlNet when it is installed,
        # PyTorch 2 scaled-dot-product attention otherwise.
        DefaultAttnProc = AttnProcessor2_0
        if is_xformers_available():
            DefaultAttnProc = XFormersAttnProcessor
        self.controlnet.set_attn_processor(DefaultAttnProc())
        self.conditioning_scale = conditioning_scale

    def __getattr__(self, name: str):
        # Fall back to the wrapped UNet for attributes this wrapper lacks.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states,
        class_labels=None,
        *args,
        cross_attention_kwargs: dict,
        **kwargs,
    ):
        """Compute ControlNet residuals and run the wrapped UNet with them.

        The depth map is taken from ``cross_attention_kwargs["control_depth"]``
        and removed from a copy of that dict (the caller's dict is untouched).
        ``class_labels``, extra positional arguments and keyword arguments
        (e.g. ``return_dict``) are forwarded to the wrapped UNet so this wrapper is
        as transparent as the UNet it wraps.

        Raises:
            KeyError: if ``"control_depth"`` is missing from ``cross_attention_kwargs``.
        """
        cross_attention_kwargs = dict(cross_attention_kwargs)
        control_depth = cross_attention_kwargs.pop("control_depth")
        down_block_res_samples, mid_block_res_sample = self.controlnet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=control_depth,
            conditioning_scale=self.conditioning_scale,
            return_dict=False,
        )
        return self.unet(
            sample,
            timestep,
            encoder_hidden_states,
            class_labels,
            *args,
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample,
            cross_attention_kwargs=cross_attention_kwargs,
            **kwargs,
        )
class ModuleListDict(torch.nn.Module):
    """Dict-like container that registers its modules in a ``ModuleList``.

    ``keys`` holds the sorted names and ``values`` the modules in that same
    order; ``container[name]`` returns the module registered under ``name``
    (``ValueError`` for unknown names).
    """

    def __init__(self, procs: dict) -> None:
        super().__init__()
        ordered_names = sorted(procs)
        self.keys = ordered_names
        self.values = torch.nn.ModuleList([procs[name] for name in ordered_names])

    def __getitem__(self, key):
        position = self.keys.index(key)
        return self.values[position]
class SuperNet(torch.nn.Module):
    """Hold named modules in a ``ModuleList`` while keeping their names in state dicts.

    Modules are stored in ``self.layers`` in sorted-name order, so their
    parameters are internally named ``layers.<index>.<param>``. Two hooks
    translate between that form and ``<name>.<param>``:

    * ``state_dict()`` returns keys as ``<name>.<param>``;
    * ``load_state_dict()`` accepts keys as ``<name>.<param>``.

    Names containing ``.processor`` / ``.self_attn`` are matched up to and
    including that marker; otherwise the name is the first dotted component.
    """

    def __init__(self, state_dict: Dict[str, torch.Tensor]):
        super().__init__()
        state_dict = OrderedDict((k, state_dict[k]) for k in sorted(state_dict.keys()))
        self.layers = torch.nn.ModuleList(state_dict.values())
        self.mapping = dict(enumerate(state_dict.keys()))
        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
        # .processor for unet, .self_attn for text encoder
        self.split_keys = [".processor", ".self_attn"]

        # Hooks on state_dict() and load_state_dict() so that key names match
        # the original module names (e.g. ``unet.attn_processors`` keys).
        def map_to(module, state_dict, *args, **kwargs):
            new_state_dict = {}
            for key, value in state_dict.items():
                num = int(key.split(".")[1])  # 0 is always "layers"
                # Replace only the leading "layers.<num>" prefix.
                new_key = key.replace(f"layers.{num}", module.mapping[num], 1)
                new_state_dict[new_key] = value
            return new_state_dict

        def remap_key(key):
            for k in self.split_keys:
                if k in key:
                    return key.split(k)[0] + k
            return key.split(".")[0]

        def map_from(module, state_dict, *args, **kwargs):
            all_keys = list(state_dict.keys())
            for key in all_keys:
                replace_key = remap_key(key)
                # ``replace_key`` is a prefix of ``key``; replace only that first
                # occurrence. Replacing every occurrence would corrupt keys in which
                # the module name reappears, e.g. "controlnet.controlnet_down_blocks..."
                new_key = key.replace(
                    replace_key, f"layers.{module.rev_mapping[replace_key]}", 1
                )
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)
class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
    """Zero123++ multi-view diffusion pipeline with editing/refinement entry points.

    Extends ``diffusers.StableDiffusionPipeline``. ``prepare()`` wraps the UNet in
    ``RefOnlyNoisedUNet`` so the condition image's VAE latent (``cond_lat``) is
    injected through reference-only attention. Entry points:

    * ``__call__``: generate from a condition image (optionally with a depth
      ControlNet added via ``add_controlnet``).
    * ``edit_latents``: run the parent pipeline starting from the VAE latents of a
      multi-view source image.
    * ``sdedit``: noise and re-denoise the latents of an image (``forward_sdedit``).
    * ``refine``: denoise with the encoded image concatenated to the UNet input
      channels (``forward_pipeline``).
    """

    tokenizer: transformers.CLIPTokenizer
    text_encoder: transformers.CLIPTextModel
    vision_encoder: transformers.CLIPVisionModelWithProjection
    feature_extractor_clip: transformers.CLIPImageProcessor
    unet: UNet2DConditionModel
    scheduler: diffusers.schedulers.KarrasDiffusionSchedulers
    vae: AutoencoderKL
    ramping: nn.Linear
    feature_extractor_vae: transformers.CLIPImageProcessor

    # Depth map -> tensor: ToTensor, then Normalize(mean=0.5, std=0.5),
    # i.e. (x - 0.5) / 0.5 per channel.
    depth_transforms_multi = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
    )

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vision_encoder: transformers.CLIPVisionModelWithProjection,
        feature_extractor_clip: CLIPImageProcessor,
        feature_extractor_vae: CLIPImageProcessor,
        ramping_coefficients: Optional[list] = None,
        safety_checker=None,
    ):
        """Register the pipeline components.

        ``safety_checker`` is accepted for signature compatibility, but ``None``
        is always registered. ``ramping_coefficients`` is stored in the config
        and used to weight the CLIP image embedding added to the prompt
        embedding.
        """
        # Initialise via DiffusionPipeline directly, skipping
        # StableDiffusionPipeline.__init__.
        DiffusionPipeline.__init__(self)
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=None,
            vision_encoder=vision_encoder,
            feature_extractor_clip=feature_extractor_clip,
            feature_extractor_vae=feature_extractor_vae,
        )
        self.register_to_config(ramping_coefficients=ramping_coefficients)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    def prepare(self):
        """Wrap a plain ``UNet2DConditionModel`` in ``RefOnlyNoisedUNet`` (eval mode).

        Does nothing once ``self.unet`` is already wrapped, so every entry point
        can call it.
        """
        train_sched = DDPMScheduler.from_config(self.scheduler.config)
        if isinstance(self.unet, UNet2DConditionModel):
            self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval()

    def add_controlnet(
        self,
        controlnet: Optional[diffusers.ControlNetModel] = None,
        conditioning_scale=1.0,
    ):
        """Wrap the prepared UNet in a ``DepthControlUNet``.

        Args:
            controlnet: ControlNet to use; ``DepthControlUNet`` creates one from
                the UNet when ``None``.
            conditioning_scale: forwarded to the ControlNet as
                ``conditioning_scale``.

        Returns:
            A ``SuperNet`` holding the ControlNet under the name ``"controlnet"``.
        """
        self.prepare()
        self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale)
        return SuperNet(OrderedDict([("controlnet", self.unet.controlnet)]))

    def encode_condition_image(self, image: torch.Tensor):
        """Return a sample of the VAE latent distribution of ``image`` (unscaled)."""
        image = self.vae.encode(image).latent_dist.sample()
        return image

    def _pil_to_tensor(self, pil_image):
        """Convert an RGB PIL image to a ``(1, 3, H, W)`` tensor with values in ``[0, 1]``.

        The tensor is placed on ``self.vae.device`` and cast to ``self.vae.dtype``,
        so it can be passed to the VAE at any precision (fp32, fp16 or bf16).
        """
        pixels = torch.from_numpy(numpy.array(pil_image)).to(self.vae.device) / 255.0
        return pixels.permute(2, 0, 1).unsqueeze(0).to(self.vae.dtype)

    @spaces.GPU(duration=60)
    @torch.no_grad()
    def edit_latents(
        self,
        image_guidance: Image.Image,
        multiview_source_image: Image.Image = None,
        edit_strength: float = 1.0,
        prompt="",
        *args,
        guidance_scale=0.0,
        output_type: Optional[str] = "pil",
        width=640,
        height=960,
        num_inference_steps=28,
        return_dict=True,
        **kwargs,
    ):
        """Run the parent pipeline starting from the latents of a multi-view image.

        ``image_guidance`` provides the reference-attention condition latent.
        ``multiview_source_image`` provides both the CLIP global embedding and the
        starting latents.

        NOTE(review): ``latents=`` is forwarded to
        ``StableDiffusionPipeline.__call__``, which presumably calls
        ``self.prepare_latents`` without ``timestep``. The override in this class
        raises ``ValueError`` in that case. ``edit_strength`` is also handed to
        the parent call, which may ignore it. The starting latents are not passed
        through ``scale_latents`` although the result is ``unscale_latents``-ed.
        Confirm this entry point before relying on it.

        Raises:
            ValueError: if ``image_guidance`` or ``multiview_source_image`` is None.
        """
        self.prepare()
        if image_guidance is None:
            raise ValueError(
                "Inputting embeddings not supported for this pipeline. Please pass an image."
            )
        if multiview_source_image is None:
            raise ValueError("Multiview source image is required for this pipeline.")
        assert not isinstance(image_guidance, torch.Tensor)
        assert not isinstance(multiview_source_image, torch.Tensor)
        image_guidance = to_rgb_image(image_guidance)
        image_source = to_rgb_image(multiview_source_image)
        image_guidance_1 = self.feature_extractor_vae(
            images=image_guidance, return_tensors="pt"
        ).pixel_values
        image_guidance_2 = self.feature_extractor_clip(
            images=image_source, return_tensors="pt"
        ).pixel_values
        image_guidance = image_guidance_1.to(
            device=self.vae.device, dtype=self.vae.dtype
        )
        image_guidance_2 = image_guidance_2.to(
            device=self.vae.device, dtype=self.vae.dtype
        )
        cond_lat = self.encode_condition_image(image_guidance)
        # Prepend the negative (all-zero image) latent only when classifier-free
        # guidance doubles the batch, as in __call__; otherwise cond_lat would
        # have twice the batch size of the sample being denoised.
        if guidance_scale > 1:
            negative_lat = self.encode_condition_image(torch.zeros_like(image_guidance))
            cond_lat = torch.cat([negative_lat, cond_lat])
        encoded = self.vision_encoder(image_guidance_2, output_hidden_states=False)
        global_embeds = encoded.image_embeds
        global_embeds = global_embeds.unsqueeze(-2)
        if hasattr(self, "encode_prompt"):
            encoder_hidden_states = self.encode_prompt(prompt, self.device, 1, False)[0]
        else:
            encoder_hidden_states = self._encode_prompt(prompt, self.device, 1, False)
        # Add the CLIP image embedding to the prompt embedding, scaled by one
        # ramping coefficient per token position.
        ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
        encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
        cak = dict(cond_lat=cond_lat)
        # Use the RGB-converted source; an RGBA input would otherwise give the VAE
        # four channels.
        mv_image = self._pil_to_tensor(image_source)
        latents = (
            self.vae.encode(mv_image * 2.0 - 1.0).latent_dist.sample()
            * self.vae.config.scaling_factor
        )
        latents: torch.Tensor = (
            super()
            .__call__(
                None,
                *args,
                cross_attention_kwargs=cak,
                guidance_scale=guidance_scale,
                num_images_per_prompt=1,
                prompt_embeds=encoder_hidden_states,
                num_inference_steps=num_inference_steps,
                output_type="latent",
                width=width,
                height=height,
                latents=latents,
                edit_strength=edit_strength,
                **kwargs,
            )
            .images
        )
        latents = unscale_latents(latents)
        if output_type != "latent":
            image = unscale_image(
                self.vae.decode(
                    latents / self.vae.config.scaling_factor, return_dict=False
                )[0]
            )
        else:
            image = latents
        image = self.image_processor.postprocess(image, output_type=output_type)
        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)

    @torch.no_grad()
    def encode_target_images(self, images):
        """Encode images in ``[0, 1]`` into scaled latents (see ``scale_latents``)."""
        dtype = next(self.vae.parameters()).dtype
        # equals to scaling images to [-1, 1] first and then call scale_image
        images = (images - 0.5) / 0.8  # [-0.625, 0.625]
        posterior = self.vae.encode(images.to(dtype)).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor
        latents = scale_latents(latents)
        return latents

    @spaces.GPU(duration=60)
    @torch.no_grad()
    def sdedit(
        self,
        image,
        *args,
        cond_image: Image.Image = None,
        output_type: Optional[str] = "pil",
        width=640,
        height=960,
        num_inference_steps=75,
        edit_strength=1.0,
        return_dict=True,
        guidance_scale=0.0,
        **kwargs,
    ):
        """SDEdit an image: encode it, noise it partway and denoise with an empty prompt.

        Args:
            image: PIL image to edit (RGB or RGBA).
            cond_image: optional image whose VAE latent is passed as ``cond_lat``;
                the latent of an all-zero image is used when omitted. The
                ``dont_forward_cond_state`` flag is always set, so no conditioning
                pass is run.
            edit_strength: fraction of the scheduler's timesteps to run (see
                ``forward_sdedit``).

        Raises:
            ValueError: if ``image`` is None.
        """
        self.prepare()
        if image is None:
            raise ValueError(
                "Inputting embeddings not supported for this pipeline. Please pass an image."
            )
        assert not isinstance(image, torch.Tensor)
        image = to_rgb_image(image)
        if hasattr(self, "encode_prompt"):
            encoder_hidden_states = self.encode_prompt([""], self.device, 1, False)[0]
        else:
            encoder_hidden_states = self._encode_prompt([""], self.device, 1, False)
        image = self._pil_to_tensor(image)
        latents = self.encode_target_images(image)
        if cond_image is not None:
            cond_image = self._pil_to_tensor(to_rgb_image(cond_image))
            cond_lat = self.encode_condition_image(cond_image)
        else:
            cond_lat = self.encode_condition_image(torch.zeros_like(image)).to(
                self.vae.device
            )
        cak = dict(cond_lat=cond_lat, dont_forward_cond_state=True)
        latents = self.forward_sdedit(
            latents,
            cross_attention_kwargs=cak,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            prompt_embeds=encoder_hidden_states,
            num_inference_steps=num_inference_steps,
            output_type="latent",
            width=width,
            height=height,
            edit_strength=edit_strength,
            **kwargs,
        ).images
        # forward_sdedit has already applied unscale_latents.
        if output_type != "latent":
            image = unscale_image(
                self.vae.decode(
                    latents / self.vae.config.scaling_factor, return_dict=False
                )[0]
            )
        else:
            image = latents
        image = self.image_processor.postprocess(image, output_type=output_type)
        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)

    @spaces.GPU(duration=60)
    @torch.no_grad()
    def refine(
        self,
        image: Image.Image = None,
        edit_image: Image.Image = None,
        prompt: Optional[str] = "",
        *args,
        output_type: Optional[str] = "pil",
        width=640,
        height=960,
        num_inference_steps=28,
        edit_strength=1.0,
        return_dict=True,
        guidance_scale=4.0,
        **kwargs,
    ):
        """Refine ``image`` by denoising with its latent concatenated to the UNet input.

        ``image`` is encoded with ``encode_target_images`` and passed to
        ``forward_pipeline`` as ``cond_latent``. Denoising starts from the encoded
        ``edit_image`` when given, otherwise from fresh noise. The output
        resolution follows ``image``; ``width`` and ``height`` are overwritten.

        Raises:
            ValueError: if ``image`` is None.
        """
        self.prepare()
        if image is None:
            raise ValueError(
                "Inputting embeddings not supported for this pipeline. Please pass an image."
            )
        assert not isinstance(image, torch.Tensor)
        image = to_rgb_image(image)
        if hasattr(self, "encode_prompt"):
            encoder_hidden_states = self.encode_prompt(prompt, self.device, 1, False)[0]
        else:
            encoder_hidden_states = self._encode_prompt(prompt, self.device, 1, False)
        latents_edit = None
        if edit_image is not None:
            edit_image = self._pil_to_tensor(to_rgb_image(edit_image))
            latents_edit = self.encode_target_images(edit_image)
        image = self._pil_to_tensor(image)
        height, width = image.shape[-2:]
        latents = self.encode_target_images(image)
        cond_lat = self.encode_condition_image(torch.zeros_like(image)).to(
            self.vae.device
        )
        cak = dict(cond_lat=cond_lat, dont_forward_cond_state=True)
        latents = self.forward_pipeline(
            latents_edit,
            latents,
            cross_attention_kwargs=cak,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            prompt_embeds=encoder_hidden_states,
            num_inference_steps=num_inference_steps,
            output_type="latent",
            width=width,
            height=height,
            edit_strength=edit_strength,
            **kwargs,
        ).images
        # forward_pipeline has already applied unscale_latents.
        if output_type != "latent":
            image = unscale_image(
                self.vae.decode(
                    latents / self.vae.config.scaling_factor, return_dict=False
                )[0]
            )
        else:
            image = latents
        image = self.image_processor.postprocess(image, output_type=output_type)
        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
        timestep=None,
    ):
        """Create initial latents, or noise the given ``latents`` to ``timestep``.

        The noise shape is ``(batch_size, num_channels_latents, height // s,
        width // s)`` with ``s = self.vae_scale_factor``. Without ``latents``,
        returns that noise scaled by ``scheduler.init_noise_sigma``. With
        ``latents``, returns ``scheduler.add_noise(latents, noise, timestep)``.

        Raises:
            ValueError: if ``generator`` is a list whose length differs from
                ``batch_size``, or if ``latents`` is given without ``timestep``.
        """
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        if latents is None:
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=dtype
            )
            # scale the initial noise by the standard deviation required by the scheduler
            latents = latents * self.scheduler.init_noise_sigma
        else:
            if timestep is None:
                raise ValueError(
                    "When passing `latents` you also need to pass `timestep`."
                )
            latents = latents.to(device)
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep)
        return latents

    @torch.no_grad()
    def forward_sdedit(
        self,
        latents: torch.Tensor,
        cross_attention_kwargs: dict,
        guidance_scale: float,
        num_images_per_prompt: int,
        prompt_embeds,
        num_inference_steps: int,
        output_type: str,
        width: int,
        height: int,
        edit_strength: float = 1.0,
    ):
        """Denoise ``latents`` SDEdit-style and return a ``StableDiffusionPipelineOutput``.

        The scheduler's timesteps are truncated to their last
        ``int(edit_strength * N)`` entries. ``latents`` (as produced by
        ``encode_target_images``) are noised to the first kept timestep and then
        denoised. The final latents are passed through ``unscale_latents``. With
        ``output_type == "latent"`` they are returned as ``images`` and
        ``nsfw_content_detected`` is ``None``.
        """
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        batch_size = prompt_embeds.shape[0]
        # NOTE(review): a fresh generator without an explicit seed; its noise is
        # presumably identical on every call -- confirm this is intended.
        generator = torch.Generator(device=latents.device)
        device = self._execution_device
        # `guidance_scale` is the guidance weight `w` of eq. (2) in the Imagen paper
        # (https://arxiv.org/pdf/2205.11487.pdf); `guidance_scale = 1` disables
        # classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None)
            if cross_attention_kwargs is not None
            else None
        )
        prompt_embeds = self._encode_prompt(
            None,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            None,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=None,
            lora_scale=text_encoder_lora_scale,
        )
        # 4. Prepare timesteps: keep the last int(edit_strength * N) of the schedule.
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        timesteps = reversed(reversed(timesteps)[: int(edit_strength * len(timesteps))])
        # 5. Noise the input latents to the first kept timestep.
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            timesteps[0:1],
        )
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0)
        # 7. Denoising loop
        num_warmup_steps = 0
        with self.progress_bar(total=len(timesteps)) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )
                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]
                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )
                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
                )[0]
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
        latents = unscale_latents(latents)
        if output_type != "latent":
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor, return_dict=False
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, prompt_embeds.dtype
            )
        else:
            image = latents
            # Raw latents are not safety-checked; must be bound before use below.
            has_nsfw_concept = None
        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize
        )
        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()
        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )

    @torch.no_grad()
    def forward_pipeline(
        self,
        latents: torch.Tensor,
        cond_latent: torch.Tensor,
        cross_attention_kwargs: dict,
        guidance_scale: float,
        num_images_per_prompt: int,
        prompt_embeds,
        num_inference_steps: int,
        output_type: str,
        width: int,
        height: int,
        edit_strength: float = 1.0,
    ):
        """Denoise with ``cond_latent`` concatenated to the UNet input along channels.

        Args:
            latents: starting latents, noised to the first kept timestep; ``None``
                starts from Gaussian noise scaled by ``init_noise_sigma``.
            cond_latent: concatenated (dim 1) to every UNet input; expanded to
                batch 2 under classifier-free guidance.

        The denoised latents have ``unet.config.in_channels // 2`` channels.
        Timesteps are truncated by ``edit_strength`` as in ``forward_sdedit``.
        Returns a ``StableDiffusionPipelineOutput`` whose ``images`` are
        ``unscale_latents``-ed latents when ``output_type == "latent"``.
        """
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        batch_size = 1
        generator = torch.Generator(device=cond_latent.device)
        device = self._execution_device
        # `guidance_scale` is the guidance weight `w` of eq. (2) in the Imagen paper
        # (https://arxiv.org/pdf/2205.11487.pdf); `guidance_scale = 1` disables
        # classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None)
            if cross_attention_kwargs is not None
            else None
        )
        prompt_embeds = self._encode_prompt(
            None,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            None,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=None,
            lora_scale=text_encoder_lora_scale,
        )
        # 4. Prepare timesteps: keep the last int(edit_strength * N) of the schedule.
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        timesteps = reversed(reversed(timesteps)[: int(edit_strength * len(timesteps))])
        # 5. Prepare latent variables; the other half of the UNet input channels
        # is taken by cond_latent.
        num_channels_latents = self.unet.config.in_channels // 2
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            timesteps[0:1],
        )
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0)
        if do_classifier_free_guidance:
            cond_latent = cond_latent.expand(batch_size * 2, -1, -1, -1)
        # 7. Denoising loop
        num_warmup_steps = 0
        with self.progress_bar(total=len(timesteps)) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )
                latent_model_input = torch.cat([latent_model_input, cond_latent], dim=1)
                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]
                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )
                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
                )[0]
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
        latents = unscale_latents(latents)
        if output_type != "latent":
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor, return_dict=False
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, prompt_embeds.dtype
            )
        else:
            image = latents
            has_nsfw_concept = None
        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize
        )
        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()
        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )

    @spaces.GPU(duration=60)
    @torch.no_grad()
    def __call__(
        self,
        image: Image.Image = None,
        source_image: Image.Image = None,
        prompt="",
        *args,
        num_images_per_prompt: Optional[int] = 1,
        guidance_scale=4.0,
        depth_image: Image.Image = None,
        output_type: Optional[str] = "pil",
        width=640,
        height=960,
        num_inference_steps=28,
        return_dict=True,
        **kwargs,
    ):
        """Generate multi-view images conditioned on ``image``.

        ``image`` is encoded twice. ``feature_extractor_vae`` + VAE produce
        ``cond_lat``, used by reference attention. ``feature_extractor_clip`` +
        ``vision_encoder`` produce a global embedding, which is added to the
        prompt embedding weighted by ``ramping_coefficients``. When a ControlNet
        was added via ``add_controlnet``, ``depth_image`` is required.
        ``source_image`` is currently unused.

        Raises:
            ValueError: if ``image`` is None, or a ControlNet has been added and
                ``depth_image`` is None.
        """
        self.prepare()
        if image is None:
            raise ValueError(
                "Inputting embeddings not supported for this pipeline. Please pass an image."
            )
        assert not isinstance(image, torch.Tensor)
        image = to_rgb_image(image)
        image_1 = self.feature_extractor_vae(
            images=image, return_tensors="pt"
        ).pixel_values
        image_2 = self.feature_extractor_clip(
            images=image, return_tensors="pt"
        ).pixel_values
        has_controlnet = hasattr(self.unet, "controlnet")
        if has_controlnet:
            # Without a depth map the ControlNet would receive None further down.
            if depth_image is None:
                raise ValueError(
                    "A depth_image is required when a ControlNet has been added."
                )
            depth_image = to_rgb_image(depth_image)
            depth_image = self.depth_transforms_multi(depth_image).to(
                device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype
            )
        image = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
        image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)
        cond_lat = self.encode_condition_image(image)
        if guidance_scale > 1:
            # Negative condition: latent of an all-zero image, matching the
            # doubled batch under classifier-free guidance.
            negative_lat = self.encode_condition_image(torch.zeros_like(image))
            cond_lat = torch.cat([negative_lat, cond_lat])
        encoded = self.vision_encoder(image_2, output_hidden_states=False)
        global_embeds = encoded.image_embeds
        global_embeds = global_embeds.unsqueeze(-2)
        if hasattr(self, "encode_prompt"):
            encoder_hidden_states = self.encode_prompt(
                prompt, self.device, num_images_per_prompt, False
            )[0]
        else:
            encoder_hidden_states = self._encode_prompt(
                prompt, self.device, num_images_per_prompt, False
            )
        # Add the CLIP image embedding to the prompt embedding, scaled by one
        # ramping coefficient per token position.
        ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
        encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
        cak = dict(cond_lat=cond_lat)
        if has_controlnet:
            cak["control_depth"] = depth_image
        latents: torch.Tensor = (
            super()
            .__call__(
                None,
                *args,
                cross_attention_kwargs=cak,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_images_per_prompt,
                prompt_embeds=encoder_hidden_states,
                num_inference_steps=num_inference_steps,
                output_type="latent",
                width=width,
                height=height,
                latents=None,
                **kwargs,
            )
            .images
        )
        latents = unscale_latents(latents)
        if output_type != "latent":
            image = unscale_image(
                self.vae.decode(
                    latents / self.vae.config.scaling_factor, return_dict=False
                )[0]
            )
        else:
            image = latents
        image = self.image_processor.postprocess(image, output_type=output_type)
        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)