import random
from typing import Callable, Dict

import torch
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from tqdm import tqdm

# from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
# from diffusers import AutoencoderKL, UNet2DConditionModel


def get_scaled_coeffs():
    """Return the coefficients (a, b) with a = sqrt(beta_min) and b = sqrt(beta_max) - sqrt(beta_min)."""
    beta_min = 0.85
    beta_max = 12.0
    return beta_min**0.5, beta_max**0.5 - beta_min**0.5


def beta(t):
    """Beta schedule ``(a + t * b) ** 2`` interpolating beta_min to beta_max over [0, 1].

    Parameters
    ----------
    t : torch.Tensor
        Diffusion time in [0, 1].
    """
    a, b = get_scaled_coeffs()
    return (a + t * b) ** 2


def int_beta(t):
    """Closed-form integral of ``beta`` from 0 to ``t``.

    Parameters
    ----------
    t : torch.Tensor
        Diffusion time in [0, 1].
    """
    a, b = get_scaled_coeffs()
    return ((a + b * t) ** 3 - a**3) / (3 * b)


def sigma(t):
    """Noise scale ``sqrt(exp(int_beta(t)) - 1)`` used by the sampler.

    Parameters
    ----------
    t : torch.Tensor
        Diffusion time in [0, 1].
    """
    return torch.expm1(int_beta(t)) ** 0.5


def sigma_orig(t):
    """Noise scale ``sqrt(1 - exp(-int_beta(t)))`` of the variance-preserving form.

    Parameters
    ----------
    t : torch.Tensor
        Diffusion time in [0, 1].
    """
    return (-torch.expm1(-int_beta(t))) ** 0.5


class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
    """SDXL pipeline that samples from a superposition of two prompts."""

    def __init__(
        self,
        unet: Callable,
        vae: Callable,
        text_encoder: Callable,
        text_encoder_2: Callable,
        tokenizer: Callable,
        tokenizer_2: Callable,
    ) -> None:
        """__init__.

        Parameters
        ----------
        unet : Callable
            unet
        vae : Callable
            vae
        text_encoder : Callable
            text_encoder
        text_encoder_2 : Callable
            text_encoder_2
        tokenizer : Callable
            tokenizer
        tokenizer_2 : Callable
            tokenizer_2

        Returns
        -------
        None
        """
        super().__init__()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16
        vae.to(device)
        unet.to(device, dtype=dtype)
        text_encoder.to(device, dtype=dtype)
        text_encoder_2.to(device, dtype=dtype)

        self.register_modules(
            unet=unet,
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
        )

    def prepare_prompt_input(self, prompt_o, prompt_b, batch_size, height, width):
        """prepare_prompt_input.

        Parameters
        ----------
        prompt_o : prompt_o
        prompt_b : prompt_b
        batch_size : batch_size
        height : height
        width : width
        """
        text_input = self.tokenizer(
            prompt_o * batch_size,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_2 = self.tokenizer_2(
            prompt_o * batch_size,
            padding="max_length",
            max_length=self.tokenizer_2.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            text_embeddings = self.text_encoder(
                text_input.input_ids.to(self.device), output_hidden_states=True
            )
            text_embeddings_2 = self.text_encoder_2(
                text_input_2.input_ids.to(self.device), output_hidden_states=True
            )
        prompt_embeds_o = torch.concat(
            (text_embeddings.hidden_states[-2], text_embeddings_2.hidden_states[-2]),
            dim=-1,
        )
        pooled_prompt_embeds_o = text_embeddings_2[0]
        negative_prompt_embeds = torch.zeros_like(prompt_embeds_o)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds_o)

        text_input = self.tokenizer(
            prompt_b * batch_size,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_2 = self.tokenizer_2(
            prompt_b * batch_size,
            padding="max_length",
            max_length=self.tokenizer_2.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            text_embeddings = self.text_encoder(
                text_input.input_ids.to(self.device), output_hidden_states=True
            )
            text_embeddings_2 = self.text_encoder_2(
                text_input_2.input_ids.to(self.device), output_hidden_states=True
            )
        prompt_embeds_b = torch.concat(
            (text_embeddings.hidden_states[-2], text_embeddings_2.hidden_states[-2]),
            dim=-1,
        )
        pooled_prompt_embeds_b = text_embeddings_2[0]

        add_time_ids_o = torch.tensor([(height, width, 0, 0, height, width)])
        add_time_ids_b = torch.tensor([(height, width, 0, 0, height, width)])
        negative_add_time_ids = torch.tensor([(height, width, 0, 0, height, width)])

        prompt_embeds = torch.cat(
            [negative_prompt_embeds, prompt_embeds_o, prompt_embeds_b], dim=0
        )
        add_text_embeds = torch.cat(
            [
                negative_pooled_prompt_embeds,
                pooled_prompt_embeds_o,
                pooled_prompt_embeds_b,
            ],
            dim=0,
        )
        add_time_ids = torch.cat(
            [negative_add_time_ids, add_time_ids_o, add_time_ids_b], dim=0
        )

        prompt_embeds = prompt_embeds.to(self.device)
        add_text_embeds = add_text_embeds.to(self.device)
        add_time_ids = add_time_ids.to(self.device).repeat(batch_size, 1)

        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
        return prompt_embeds, added_cond_kwargs

    @torch.no_grad
    def get_batch(self, latents: Callable, nrow: int, ncol: int) -> Callable:
        """get_batch.

        Parameters
        ----------
        latents : Callable
            latents
        nrow : int
            nrow
        ncol : int
            ncol

        Returns
        -------
        Callable
        """
        image = self.vae.decode(
            latents / self.vae.config.scaling_factor, return_dict=False
        )[0]
        image = (image / 2 + 0.5).clamp(0, 1).squeeze()
        if len(image.shape) < 4:
            image = image.unsqueeze(0)
        image = (image.permute(0, 2, 3, 1) * 255).to(torch.uint8)
        return image

    @torch.no_grad
    def get_text_embedding(self, prompt: str) -> Callable:
        """get_text_embedding.

        Parameters
        ----------
        prompt : str
            prompt

        Returns
        -------
        Callable
        """
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        return self.text_encoder(text_input.input_ids.to(self.device))[0]

    @torch.no_grad
    def get_vel(self, t: float, sigma: float, latents: Callable, embeddings: Callable):
        """get_vel.

        Parameters
        ----------
        t : float
            t
        sigma : float
            sigma
        latents : Callable
            latents
        embeddings : Callable
            embeddings
        """

        def v(_x, _e):
            """v.

            Parameters
            ----------
            _x : _x
            _e : _e
            """
            # The registered module is ``unet``; there is no ``self.model`` attribute.
            return self.unet(
                _x / ((sigma**2 + 1) ** 0.5), t, encoder_hidden_states=_e
            ).sample

        embeds = torch.cat(embeddings)
        latent_input = latents
        vel = v(latent_input, embeds)
        return vel

    def preprocess(
        self,
        prompt_1: str,
        prompt_2: str,
        seed: int = None,
        num_inference_steps: int = 200,
        batch_size: int = 1,
        height: int = 1024,
        width: int = 1024,
        guidance_scale: float = 7.5,
    ) -> Callable:
        """preprocess.

        Parameters
        ----------
        prompt_1 : str
            prompt_1
        prompt_2 : str
            prompt_2
        seed : int
            seed
        num_inference_steps : int
            num_inference_steps
        batch_size : int
            batch_size
        height : int
            height
        width : int
            width
        guidance_scale : float
            guidance_scale

        Returns
        -------
        Callable
        """
        # Tokenize the input
        self.batch_size = batch_size
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.seed = seed
        if self.seed is None:
            self.seed = random.randint(0, 2**32 - 1)
        # Seed generator to create the initial latent noise. A device-matched
        # torch.Generator (rather than torch.cuda.manual_seed, which returns None)
        # keeps sampling reproducible on both CPU and CUDA.
        self.generator = torch.Generator(device=self.device).manual_seed(self.seed)

        latents = torch.randn(
            (batch_size, self.unet.config.in_channels, height // 8, width // 8),
            generator=self.generator,
            dtype=torch.float16,
            device=self.device,
        )

        prompt_embeds, added_cond_kwargs = self.prepare_prompt_input(
            prompt_1, prompt_2, batch_size, height, width
        )

        return {
            "latents": latents,
            "prompt_embeds": prompt_embeds,
            "added_cond_kwargs": added_cond_kwargs,
        }

    def _forward(self, model_inputs: Dict) -> Callable:
        """_forward.

        Parameters
        ----------
        model_inputs : Dict
            model_inputs

        Returns
        -------
        Callable
        """
        latents = model_inputs["latents"]
        prompt_embeds = model_inputs["prompt_embeds"]
        added_cond_kwargs = model_inputs["added_cond_kwargs"]

        t = torch.tensor(1.0)
        dt = 1.0 / self.num_inference_steps
        train_number_steps = 1000
        # Rescale the initial noise to the variance-exploding parameterization.
        latents = latents * (sigma(t) ** 2 + 1) ** 0.5
        with torch.no_grad():
            for i in tqdm(range(self.num_inference_steps)):
                # Three streams: unconditional, prompt_1 ("o") and prompt_2 ("b").
                latent_model_input = torch.cat([latents] * 3)
                sigma_t = sigma(t)
                dsigma = sigma(t - dt) - sigma_t
                latent_model_input /= (sigma_t**2 + 1) ** 0.5
                with torch.no_grad():
                    noise_pred = self.unet(
                        latent_model_input,
                        t * train_number_steps,
                        encoder_hidden_states=prompt_embeds,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=False,
                    )[0]

                (
                    noise_pred_uncond,
                    noise_pred_text_o,
                    noise_pred_text_b,
                ) = noise_pred.chunk(3)

                # noise = torch.sqrt(2*torch.abs(dsigma)*sigma_t)*torch.randn_like(latents)
                noise = torch.sqrt(2 * torch.abs(dsigma) * sigma_t) * torch.empty_like(
                    latents, device=self.device
                ).normal_(generator=self.generator)

                dx_ind = (
                    2
                    * dsigma
                    * (
                        noise_pred_uncond
                        + self.guidance_scale * (noise_pred_text_b - noise_pred_uncond)
                    )
                    + noise
                )
                # kappa is the per-sample mixing coefficient that balances the two
                # prompt predictions so the update follows their superposition
                # rather than either prompt alone.
                kappa = (
                    torch.abs(dsigma)
                    * (noise_pred_text_b - noise_pred_text_o)
                    * (noise_pred_text_b + noise_pred_text_o)
                ).sum((1, 2, 3)) - (
                    dx_ind * (noise_pred_text_o - noise_pred_text_b)
                ).sum((1, 2, 3))
                kappa /= (
                    2
                    * dsigma
                    * self.guidance_scale
                    * ((noise_pred_text_o - noise_pred_text_b) ** 2).sum((1, 2, 3))
                )

                noise_pred = noise_pred_uncond + self.guidance_scale * (
                    (noise_pred_text_b - noise_pred_uncond)
                    + kappa[:, None, None, None]
                    * (noise_pred_text_o - noise_pred_text_b)
                )

                # Final few steps use a half-size deterministic update without noise.
                if i < self.num_inference_steps - 3:
                    latents += 2 * dsigma * noise_pred + noise
                else:
                    latents += dsigma * noise_pred

                t -= dt
        return latents

    def postprocess(self, latents: Callable) -> Callable:
        """postprocess.

        Parameters
        ----------
        latents : Callable
            latents

        Returns
        -------
        Callable
        """
        latents = latents / self.vae.config.scaling_factor
        latents = latents.to(torch.float32)
        with torch.no_grad():
            image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        return images

    def __call__(
        self,
        prompt_1: str,
        prompt_2: str,
        seed: int = None,
        num_inference_steps: int = 200,
        batch_size: int = 1,
        height: int = 1024,
        width: int = 1024,
        guidance_scale: float = 7.5,
    ) -> Callable:
        """__call__.

        Parameters
        ----------
        prompt_1 : str
            prompt_1
        prompt_2 : str
            prompt_2
        seed : int
            seed
        num_inference_steps : int
            num_inference_steps
        batch_size : int
            batch_size
        height : int
            height
        width : int
            width
        guidance_scale : float
            guidance_scale

        Returns
        -------
        Callable
        """
        # Preprocess inputs
        model_inputs = self.preprocess(
            prompt_1,
            prompt_2,
            seed,
            num_inference_steps,
            batch_size,
            height,
            width,
            guidance_scale,
        )

        # Forward pass through the pipeline
        latents = self._forward(model_inputs)

        # Postprocess to generate the final output
        images = self.postprocess(latents)
        return images
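
# Minimal usage sketch (illustrative only, not part of the pipeline). It assumes the
# components come from the SDXL base checkpoint "stabilityai/stable-diffusion-xl-base-1.0"
# loaded with the standard diffusers/transformers classes (the commented-out imports at
# the top of this file); adjust the checkpoint, prompts, and step count to your setup.
if __name__ == "__main__":
    from diffusers import AutoencoderKL, UNet2DConditionModel
    from PIL import Image
    from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

    model_id = "stabilityai/stable-diffusion-xl-base-1.0"  # assumed checkpoint
    unet = UNet2DConditionModel.from_pretrained(
        model_id, subfolder="unet", torch_dtype=torch.float16
    )
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
    text_encoder = CLIPTextModel.from_pretrained(
        model_id, subfolder="text_encoder", torch_dtype=torch.float16
    )
    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
        model_id, subfolder="text_encoder_2", torch_dtype=torch.float16
    )
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    tokenizer_2 = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")

    pipeline = SuperDiffSDXLPipeline(
        unet=unet,
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
    )

    # Superimpose two prompts during sampling; the result is a batch of uint8 arrays.
    images = pipeline(
        prompt_1="a photo of a lion",
        prompt_2="a photo of a tiger",
        seed=0,
        num_inference_steps=200,
    )
    Image.fromarray(images[0]).save("superdiff_sample.png")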