superdiff-sdxl-v1-0 / pipeline.py
mskrt's picture
Update pipeline.py
7bfed41 verified
import random
from typing import Callable, Dict
import torch
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from tqdm import tqdm
# from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
# from diffusers import AutoencoderKL, UNet2DConditionModel
def get_scaled_coeffs():
    """Return the offset and slope of the square-root noise schedule.

    Returns
    -------
    tuple of float
        ``(sqrt(beta_min), sqrt(beta_max) - sqrt(beta_min))`` with
        ``beta_min = 0.85`` and ``beta_max = 12.0``.
    """
    sqrt_beta_min, sqrt_beta_max = 0.85**0.5, 12.0**0.5
    return sqrt_beta_min, sqrt_beta_max - sqrt_beta_min
def beta(t):
    """Evaluate the noise-schedule rate at time ``t``.

    Parameters
    ----------
    t :
        Time at which to evaluate the schedule (a number or a tensor).

    Returns
    -------
    ``(a + t * b) ** 2`` where ``(a, b)`` come from ``get_scaled_coeffs``.
    """
    offset, slope = get_scaled_coeffs()
    sqrt_rate = offset + t * slope
    return sqrt_rate**2
def int_beta(t):
    """Evaluate the integral of ``beta`` from 0 to ``t`` in closed form.

    Parameters
    ----------
    t :
        Upper integration limit (a number or a tensor).

    Returns
    -------
    ``((a + b * t) ** 3 - a ** 3) / (3 * b)`` where ``(a, b)`` come from
    ``get_scaled_coeffs``.
    """
    offset, slope = get_scaled_coeffs()
    cubic_at_t = (offset + slope * t) ** 3
    return (cubic_at_t - offset**3) / (3 * slope)
def sigma(t):
    """Return ``sqrt(exp(int_beta(t)) - 1)``.

    Parameters
    ----------
    t :
        Time value; must yield a tensor from ``int_beta`` because the result
        is passed to ``torch.expm1``.

    Returns
    -------
    torch.Tensor
        Square root of ``expm1(int_beta(t))``.
    """
    accumulated = int_beta(t)
    return torch.expm1(accumulated).pow(0.5)
def sigma_orig(t):
    """Return ``sqrt(1 - exp(-int_beta(t)))``.

    Parameters
    ----------
    t :
        Time value; must yield a tensor from ``int_beta`` because the result
        is passed to ``torch.expm1``.

    Returns
    -------
    torch.Tensor
        Square root of ``-expm1(-int_beta(t))``.
    """
    accumulated = int_beta(t)
    return torch.expm1(-accumulated).neg().pow(0.5)
class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
    """Two-prompt Stable Diffusion XL sampling pipeline.

    The pipeline encodes two prompts, evaluates the UNet on an unconditional
    branch plus one branch per prompt at every step, and mixes the two
    conditional predictions with a per-sample weight ``kappa`` inside a
    stochastic sampling loop (see ``_forward``). The final latents are decoded
    by the VAE into ``uint8`` images.
    """

    def __init__(
        self,
        unet: Callable,
        vae: Callable,
        text_encoder: Callable,
        text_encoder_2: Callable,
        tokenizer: Callable,
        tokenizer_2: Callable,
    ) -> None:
        """Move the sub-models to the available device and register them.

        Parameters
        ----------
        unet : Callable
            Denoising UNet; moved to the device and cast to ``float16``.
        vae : Callable
            Autoencoder used to decode latents; moved to the device with its
            dtype left unchanged.
        text_encoder : Callable
            First text encoder; moved to the device and cast to ``float16``.
        text_encoder_2 : Callable
            Second text encoder; moved to the device and cast to ``float16``.
        tokenizer : Callable
            Tokenizer paired with ``text_encoder``.
        tokenizer_2 : Callable
            Tokenizer paired with ``text_encoder_2``.

        Returns
        -------
        None
        """
        super().__init__()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16
        vae.to(device)
        unet.to(device, dtype=dtype)
        text_encoder.to(device, dtype=dtype)
        text_encoder_2.to(device, dtype=dtype)
        self.register_modules(
            unet=unet,
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
        )

    @staticmethod
    def _as_prompt_batch(prompt, batch_size):
        """Return ``prompt`` repeated ``batch_size`` times as a list of strings."""
        # A bare ``str * batch_size`` would concatenate the text into a single
        # longer prompt instead of building a batch, so wrap strings first.
        if isinstance(prompt, str):
            prompt = [prompt]
        return list(prompt) * batch_size

    def _encode_prompt(self, prompts):
        """Encode ``prompts`` with both text encoders.

        Returns the concatenated penultimate hidden states of the two encoders
        and the first output of the second encoder (used as pooled embedding).
        """
        token_ids = []
        for tokenizer in (self.tokenizer, self.tokenizer_2):
            tokens = tokenizer(
                prompts,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            token_ids.append(tokens.input_ids.to(self.device))
        with torch.no_grad():
            encoded_1 = self.text_encoder(token_ids[0], output_hidden_states=True)
            encoded_2 = self.text_encoder_2(token_ids[1], output_hidden_states=True)
        prompt_embeds = torch.cat(
            (encoded_1.hidden_states[-2], encoded_2.hidden_states[-2]),
            dim=-1,
        )
        return prompt_embeds, encoded_2[0]

    def prepare_prompt_input(self, prompt_o, prompt_b, batch_size, height, width):
        """Build the UNet text conditioning for both prompts.

        Parameters
        ----------
        prompt_o :
            First prompt (a string or a list of strings).
        prompt_b :
            Second prompt (a string or a list of strings).
        batch_size :
            Number of times each prompt is repeated.
        height :
            Image height written into the time ids.
        width :
            Image width written into the time ids.

        Returns
        -------
        tuple
            ``(prompt_embeds, added_cond_kwargs)``. Both are stacked along the
            batch dimension in the order ``[zeros, prompt_o, prompt_b]``;
            ``added_cond_kwargs`` holds ``"text_embeds"`` and ``"time_ids"``.
        """
        embeds_o, pooled_o = self._encode_prompt(
            self._as_prompt_batch(prompt_o, batch_size)
        )
        embeds_b, pooled_b = self._encode_prompt(
            self._as_prompt_batch(prompt_b, batch_size)
        )

        # The unconditional branch is conditioned on all-zero embeddings.
        prompt_embeds = torch.cat(
            [torch.zeros_like(embeds_o), embeds_o, embeds_b], dim=0
        ).to(self.device)
        add_text_embeds = torch.cat(
            [torch.zeros_like(pooled_o), pooled_o, pooled_b], dim=0
        ).to(self.device)

        # Every branch uses the same time-id row, so one row is repeated to
        # match the batch size of the text conditioning.
        time_ids_row = torch.tensor([(height, width, 0, 0, height, width)])
        add_time_ids = time_ids_row.to(self.device).repeat(
            prompt_embeds.shape[0], 1
        )

        added_cond_kwargs = {
            "text_embeds": add_text_embeds,
            "time_ids": add_time_ids,
        }
        return prompt_embeds, added_cond_kwargs

    @torch.no_grad()
    def get_batch(self, latents: torch.Tensor, nrow: int, ncol: int) -> torch.Tensor:
        """Decode latents into a batch of ``uint8`` images.

        Parameters
        ----------
        latents : torch.Tensor
            Latents to decode; divided by the VAE scaling factor first.
        nrow : int
            Unused.
        ncol : int
            Unused.

        Returns
        -------
        torch.Tensor
            ``uint8`` tensor with the channel dimension moved last.
        """
        image = self.vae.decode(
            latents / self.vae.config.scaling_factor, return_dict=False
        )[0]
        image = (image / 2 + 0.5).clamp(0, 1).squeeze()
        # ``squeeze`` drops the batch dimension for a single image; restore it.
        if len(image.shape) < 4:
            image = image.unsqueeze(0)
        image = (image.permute(0, 2, 3, 1) * 255).to(torch.uint8)
        return image

    @torch.no_grad()
    def get_text_embedding(self, prompt: str) -> torch.Tensor:
        """Return the first output of the first text encoder for ``prompt``.

        Parameters
        ----------
        prompt : str
            Text to tokenize (padded/truncated to the tokenizer's maximum
            length) and encode.

        Returns
        -------
        torch.Tensor
            ``self.text_encoder(...)[0]`` for the tokenized prompt.
        """
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        return self.text_encoder(text_input.input_ids.to(self.device))[0]

    @torch.no_grad()
    def get_vel(
        self,
        t: float,
        sigma: float,
        latents: torch.Tensor,
        embeddings: Callable,
        added_cond_kwargs: Dict = None,
    ):
        """Evaluate the UNet on sigma-scaled latents.

        Parameters
        ----------
        t : float
            Timestep passed to the UNet.
        sigma : float
            Noise level; the latents are divided by ``sqrt(sigma**2 + 1)``.
        latents : torch.Tensor
            Latents to evaluate.
        embeddings : sequence of torch.Tensor
            Text embeddings, concatenated with ``torch.cat`` before the call.
        added_cond_kwargs : Dict, optional
            Extra conditioning forwarded to the UNet (the same argument that
            ``_forward`` passes).

        Returns
        -------
        The ``sample`` attribute of the UNet output.
        """
        scaled_latents = latents / ((sigma**2 + 1) ** 0.5)
        # The network is registered as ``unet``; there is no ``model`` attribute.
        return self.unet(
            scaled_latents,
            t,
            encoder_hidden_states=torch.cat(embeddings),
            added_cond_kwargs=added_cond_kwargs,
        ).sample

    def preprocess(
        self,
        prompt_1: str,
        prompt_2: str,
        seed: int = None,
        num_inference_steps: int = 200,
        batch_size: int = 1,
        height: int = 1024,
        width: int = 1024,
        guidance_scale: float = 7.5,
    ) -> Dict:
        """Store the sampling settings and build the model inputs.

        Parameters
        ----------
        prompt_1 : str
            First prompt.
        prompt_2 : str
            Second prompt.
        seed : int
            Seed for the noise generator; a random 32-bit seed is drawn when
            ``None``.
        num_inference_steps : int
            Number of sampling steps.
        batch_size : int
            Number of images to generate.
        height : int
            Image height; the latent height is ``height // 8``.
        width : int
            Image width; the latent width is ``width // 8``.
        guidance_scale : float
            Guidance weight used in ``_forward``.

        Returns
        -------
        Dict
            ``{"latents", "prompt_embeds", "added_cond_kwargs"}``.
        """
        self.batch_size = batch_size
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.seed = seed
        if self.seed is None:
            self.seed = random.randint(0, 2**32 - 1)

        # ``torch.cuda.manual_seed`` returns ``None`` and does nothing without
        # CUDA, so use a dedicated generator on the pipeline device instead.
        self.generator = torch.Generator(device=self.device).manual_seed(self.seed)

        latents = torch.randn(
            (batch_size, self.unet.config.in_channels, height // 8, width // 8),
            generator=self.generator,
            dtype=torch.float16,
            device=self.device,
        )
        prompt_embeds, added_cond_kwargs = self.prepare_prompt_input(
            prompt_1, prompt_2, batch_size, height, width
        )
        return {
            "latents": latents,
            "prompt_embeds": prompt_embeds,
            "added_cond_kwargs": added_cond_kwargs,
        }

    def _forward(self, model_inputs: Dict) -> torch.Tensor:
        """Run the sampling loop from ``t = 1`` down to ``t = 0``.

        Parameters
        ----------
        model_inputs : Dict
            Output of ``preprocess``.

        Returns
        -------
        torch.Tensor
            Final latents.
        """
        latents = model_inputs["latents"]
        prompt_embeds = model_inputs["prompt_embeds"]
        added_cond_kwargs = model_inputs["added_cond_kwargs"]

        t = torch.tensor(1.0)
        dt = 1.0 / self.num_inference_steps
        train_number_steps = 1000
        # Scale the unit-variance noise up to the initial noise level.
        latents = latents * (sigma(t) ** 2 + 1) ** 0.5

        with torch.no_grad():
            for i in tqdm(range(self.num_inference_steps)):
                # One copy of the latents per branch: uncond, prompt_o, prompt_b.
                latent_model_input = torch.cat([latents] * 3)
                sigma_t = sigma(t)
                # Accumulated rounding in ``t`` can push the last ``t - dt``
                # slightly below 0, where ``sigma`` is NaN; clamp the time and
                # treat a NaN from rounding at 0 as zero noise.
                t_next = torch.clamp(t - dt, min=0.0)
                sigma_next = torch.nan_to_num(sigma(t_next), nan=0.0)
                dsigma = sigma_next - sigma_t
                latent_model_input /= (sigma_t**2 + 1) ** 0.5

                noise_pred = self.unet(
                    latent_model_input,
                    t * train_number_steps,
                    encoder_hidden_states=prompt_embeds,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]
                (
                    noise_pred_uncond,
                    noise_pred_text_o,
                    noise_pred_text_b,
                ) = noise_pred.chunk(3)

                noise = torch.sqrt(2 * torch.abs(dsigma) * sigma_t) * torch.empty_like(
                    latents, device=self.device
                ).normal_(generator=self.generator)

                # Update that would be taken when guiding with prompt_b alone.
                dx_ind = (
                    2
                    * dsigma
                    * (
                        noise_pred_uncond
                        + self.guidance_scale
                        * (noise_pred_text_b - noise_pred_uncond)
                    )
                    + noise
                )

                delta_ob = noise_pred_text_o - noise_pred_text_b
                # Per-sample mixing weight between the two prompt predictions.
                # NOTE(review): the float16 reductions below may overflow, and
                # the denominator is zero when both predictions coincide --
                # confirm whether a float32 reduction / guard is wanted.
                kappa = (
                    torch.abs(dsigma)
                    * (noise_pred_text_b - noise_pred_text_o)
                    * (noise_pred_text_b + noise_pred_text_o)
                ).sum((1, 2, 3)) - (dx_ind * delta_ob).sum((1, 2, 3))
                kappa /= (
                    2
                    * dsigma
                    * self.guidance_scale
                    * (delta_ob**2).sum((1, 2, 3))
                )

                noise_pred = noise_pred_uncond + self.guidance_scale * (
                    (noise_pred_text_b - noise_pred_uncond)
                    + kappa[:, None, None, None] * delta_ob
                )

                # All but the last three steps are stochastic; the last three
                # add no noise and use half the step coefficient.
                if i < self.num_inference_steps - 3:
                    latents += 2 * dsigma * noise_pred + noise
                else:
                    latents += dsigma * noise_pred
                t -= dt
        return latents

    def postprocess(self, latents: torch.Tensor):
        """Decode latents with the VAE into ``uint8`` images.

        Parameters
        ----------
        latents : torch.Tensor
            Latents returned by ``_forward``.

        Returns
        -------
        numpy.ndarray
            ``uint8`` array with the channel dimension moved last.
        """
        latents = latents / self.vae.config.scaling_factor
        latents = latents.to(torch.float32)
        with torch.no_grad():
            image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        return images

    def __call__(
        self,
        prompt_1: str,
        prompt_2: str,
        seed: int = None,
        num_inference_steps: int = 200,
        batch_size: int = 1,
        height: int = 1024,
        width: int = 1024,
        guidance_scale: float = 7.5,
    ):
        """Generate images from two prompts.

        Parameters
        ----------
        prompt_1 : str
            First prompt.
        prompt_2 : str
            Second prompt.
        seed : int
            Seed for the noise generator; random when ``None``.
        num_inference_steps : int
            Number of sampling steps.
        batch_size : int
            Number of images to generate.
        height : int
            Image height.
        width : int
            Image width.
        guidance_scale : float
            Guidance weight.

        Returns
        -------
        numpy.ndarray
            ``uint8`` images produced by ``postprocess``.
        """
        model_inputs = self.preprocess(
            prompt_1,
            prompt_2,
            seed,
            num_inference_steps,
            batch_size,
            height,
            width,
            guidance_scale,
        )
        latents = self._forward(model_inputs)
        return self.postprocess(latents)