import random
from typing import Callable, Dict, List, Optional

import torch
from tqdm import tqdm

from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin


def get_scaled_coeffs():
    """Square-root parametrization of the VP-SDE beta schedule."""
    beta_min = 0.85
    beta_max = 12.0
    return beta_min**0.5, beta_max**0.5 - beta_min**0.5


def beta(t):
    a, b = get_scaled_coeffs()
    return (a + t * b) ** 2


def int_beta(t):
    a, b = get_scaled_coeffs()
    return ((a + b * t) ** 3 - a**3) / (3 * b)


def sigma(t):
    return torch.expm1(int_beta(t)) ** 0.5


def sigma_orig(t):
    return (-torch.expm1(-int_beta(t))) ** 0.5


class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
    """SuperDiff pipeline for SDXL: samples one image under two prompts combined at inference time."""

    def __init__(
        self,
        unet: Callable,
        vae: Callable,
        text_encoder: Callable,
        text_encoder_2: Callable,
        tokenizer: Callable,
        tokenizer_2: Callable,
    ) -> None:
        """__init__.

        Parameters
        ----------
        unet : Callable
            unet
        vae : Callable
            vae
        text_encoder : Callable
            text_encoder
        text_encoder_2 : Callable
            text_encoder_2
        tokenizer : Callable
            tokenizer
        tokenizer_2 : Callable
            tokenizer_2

        Returns
        -------
        None
        """
        super().__init__()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        vae.to(device)
        unet.to(device)
        text_encoder.to(device)
        text_encoder_2.to(device)

        self.register_modules(
            unet=unet,
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
        )

    def prepare_prompt_input(self, prompt_o, prompt_b, batch_size, height, width):
        """Encode both prompts plus an unconditional branch and build the SDXL conditioning tensors."""
        # Prompt 1 ("o" branch): both text encoders, penultimate hidden states concatenated.
        text_input = self.tokenizer(
            [prompt_o] * batch_size,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_2 = self.tokenizer_2(
            [prompt_o] * batch_size,
            padding="max_length",
            max_length=self.tokenizer_2.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            text_embeddings = self.text_encoder(
                text_input.input_ids.to(self.device), output_hidden_states=True
            )
            text_embeddings_2 = self.text_encoder_2(
                text_input_2.input_ids.to(self.device), output_hidden_states=True
            )
        prompt_embeds_o = torch.concat(
            (text_embeddings.hidden_states[-2], text_embeddings_2.hidden_states[-2]), dim=-1
        )
        pooled_prompt_embeds_o = text_embeddings_2[0]
        negative_prompt_embeds = torch.zeros_like(prompt_embeds_o)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds_o)

        # Prompt 2 ("b" branch).
        text_input = self.tokenizer(
            [prompt_b] * batch_size,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_2 = self.tokenizer_2(
            [prompt_b] * batch_size,
            padding="max_length",
            max_length=self.tokenizer_2.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            text_embeddings = self.text_encoder(
                text_input.input_ids.to(self.device), output_hidden_states=True
            )
            text_embeddings_2 = self.text_encoder_2(
                text_input_2.input_ids.to(self.device), output_hidden_states=True
            )
        prompt_embeds_b = torch.concat(
            (text_embeddings.hidden_states[-2], text_embeddings_2.hidden_states[-2]), dim=-1
        )
        pooled_prompt_embeds_b = text_embeddings_2[0]

        # SDXL micro-conditioning: (original_size, crop_top_left, target_size).
        add_time_ids_o = torch.tensor([(height, width, 0, 0, height, width)])
        add_time_ids_b = torch.tensor([(height, width, 0, 0, height, width)])
        negative_add_time_ids = torch.tensor([(height, width, 0, 0, height, width)])

        # Stack the unconditional, prompt-1, and prompt-2 branches along the batch dimension.
        prompt_embeds = torch.cat(
            [negative_prompt_embeds, prompt_embeds_o, prompt_embeds_b], dim=0
        )
        add_text_embeds = torch.cat(
            [negative_pooled_prompt_embeds, pooled_prompt_embeds_o, pooled_prompt_embeds_b], dim=0
        )
        add_time_ids = torch.cat(
            [negative_add_time_ids, add_time_ids_o, add_time_ids_b], dim=0
        )

        prompt_embeds = prompt_embeds.to(self.device)
        add_text_embeds = add_text_embeds.to(self.device)
        add_time_ids = add_time_ids.to(self.device).repeat(batch_size, 1)

        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
        return prompt_embeds, added_cond_kwargs

    @torch.no_grad()
    def get_batch(self, latents: Callable, nrow: int, ncol: int) -> Callable:
        """get_batch.

        Parameters
        ----------
        latents : Callable
            latents
        nrow : int
            nrow
        ncol : int
            ncol

        Returns
        -------
        Callable
        """
        image = self.vae.decode(
            latents / self.vae.config.scaling_factor, return_dict=False
        )[0]
        image = (image / 2 + 0.5).clamp(0, 1).squeeze()
        if len(image.shape) < 4:
            image = image.unsqueeze(0)
        image = (image.permute(0, 2, 3, 1) * 255).to(torch.uint8)
        return image

    @torch.no_grad()
    def get_text_embedding(self, prompt: str) -> Callable:
        """get_text_embedding.

        Parameters
        ----------
        prompt : str
            prompt

        Returns
        -------
        Callable
        """
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        return self.text_encoder(text_input.input_ids.to(self.device))[0]

    @torch.no_grad()
    def get_vel(self, t: float, sigma: float, latents: Callable, embeddings: Callable):
        """get_vel.

        Parameters
        ----------
        t : float
            t
        sigma : float
            sigma
        latents : Callable
            latents
        embeddings : Callable
            embeddings
        """

        def v(_x, _e):
            return self.unet(
                _x / ((sigma**2 + 1) ** 0.5), t, encoder_hidden_states=_e
            ).sample

        embeds = torch.cat(embeddings)
        latent_input = latents
        vel = v(latent_input, embeds)
        return vel

    def preprocess(
        self,
        prompt_1: str,
        prompt_2: str,
        seed: Optional[int] = None,
        num_inference_steps: int = 1000,
        batch_size: int = 1,
        lift: float = 0.0,
        height: int = 512,
        width: int = 512,
        guidance_scale: float = 7.5,
    ) -> Callable:
        """preprocess.

        Parameters
        ----------
        prompt_1 : str
            prompt_1
        prompt_2 : str
            prompt_2
        seed : Optional[int]
            seed
        num_inference_steps : int
            num_inference_steps
        batch_size : int
            batch_size
        lift : float
            lift
        height : int
            height
        width : int
            width
        guidance_scale : float
            guidance_scale

        Returns
        -------
        Callable
        """
        self.batch_size = batch_size
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.lift = lift
        self.seed = seed
        if self.seed is None:
            self.seed = random.randint(0, 2**32 - 1)

        # Seeded generator for the initial latent noise (works on both CPU and CUDA).
        generator = torch.Generator(device=self.device).manual_seed(self.seed)
        latents = torch.randn(
            (batch_size, self.unet.config.in_channels, height // 8, width // 8),
            generator=generator,
            dtype=self.dtype,
            device=self.device,
        )
        prompt_embeds, added_cond_kwargs = self.prepare_prompt_input(
            prompt_1, prompt_2, batch_size, height, width
        )

        return {
            "latents": latents,
            "prompt_embeds": prompt_embeds,
            "added_cond_kwargs": added_cond_kwargs,
        }

    def _forward(self, model_inputs: Dict) -> Callable:
        """_forward.
        Parameters
        ----------
        model_inputs : Dict
            model_inputs

        Returns
        -------
        Callable
        """
        latents = model_inputs["latents"]
        prompt_embeds = model_inputs["prompt_embeds"]
        added_cond_kwargs = model_inputs["added_cond_kwargs"]

        t = torch.tensor(1.0)
        dt = 1.0 / self.num_inference_steps
        train_number_steps = 1000
        # Move the initial noise into the scaled parametrization used by the integrator below.
        latents = latents * (sigma(t) ** 2 + 1) ** 0.5

        with torch.no_grad():
            for i in tqdm(range(self.num_inference_steps)):
                # Three branches share the same latent: unconditional, prompt_1 ("o"), prompt_2 ("b").
                latent_model_input = torch.cat([latents] * 3)
                sigma_t = sigma(t)
                dsigma = sigma(t - dt) - sigma_t
                latent_model_input /= (sigma_t**2 + 1) ** 0.5
                with torch.no_grad():
                    noise_pred = self.unet(
                        latent_model_input,
                        t * train_number_steps,
                        encoder_hidden_states=prompt_embeds,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=False,
                    )[0]

                noise_pred_uncond, noise_pred_text_o, noise_pred_text_b = noise_pred.chunk(3)

                # Stochastic term of the reverse SDE step.
                noise = torch.sqrt(2 * torch.abs(dsigma) * sigma_t) * torch.randn_like(latents)

                # Update that would result from guiding on prompt_2 alone; used to solve for kappa.
                dx_ind = (
                    2 * dsigma
                    * (noise_pred_uncond + self.guidance_scale * (noise_pred_text_b - noise_pred_uncond))
                    + noise
                )
                # Closed-form, per-sample kappa that balances the contributions of the two prompts.
                kappa = (
                    torch.abs(dsigma)
                    * (noise_pred_text_b - noise_pred_text_o)
                    * (noise_pred_text_b + noise_pred_text_o)
                ).sum((1, 2, 3)) - (dx_ind * (noise_pred_text_o - noise_pred_text_b)).sum((1, 2, 3))
                kappa /= (
                    2 * dsigma * self.guidance_scale
                    * ((noise_pred_text_o - noise_pred_text_b) ** 2).sum((1, 2, 3))
                )

                # Combined guidance: interpolate between the two prompt predictions with kappa.
                noise_pred = noise_pred_uncond + self.guidance_scale * (
                    (noise_pred_text_b - noise_pred_uncond)
                    + kappa[:, None, None, None] * (noise_pred_text_o - noise_pred_text_b)
                )

                latents += 2 * dsigma * noise_pred + noise
                t -= dt
        return latents

    def postprocess(self, latents: Callable) -> Callable:
        """postprocess.

        Parameters
        ----------
        latents : Callable
            latents

        Returns
        -------
        Callable
        """
        latents = latents / self.vae.config.scaling_factor
        latents = latents.to(torch.float32)
        with torch.no_grad():
            image = self.vae.decode(latents, return_dict=False)[0]

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        return images

    def __call__(
        self,
        prompt_1: str,
        prompt_2: str,
        seed: Optional[int] = None,
        num_inference_steps: int = 1000,
        batch_size: int = 1,
        lift: float = 0.0,
        height: int = 512,
        width: int = 512,
        guidance_scale: float = 7.5,
    ) -> Callable:
        """__call__.

        Parameters
        ----------
        prompt_1 : str
            prompt_1
        prompt_2 : str
            prompt_2
        seed : Optional[int]
            seed
        num_inference_steps : int
            num_inference_steps
        batch_size : int
            batch_size
        lift : float
            lift
        height : int
            height
        width : int
            width
        guidance_scale : float
            guidance_scale

        Returns
        -------
        Callable
        """
        # Preprocess inputs
        model_inputs = self.preprocess(
            prompt_1,
            prompt_2,
            seed,
            num_inference_steps,
            batch_size,
            lift,
            height,
            width,
            guidance_scale,
        )

        # Forward pass through the pipeline
        latents = self._forward(model_inputs)

        # Postprocess to generate the final output
        images = self.postprocess(latents)
        return images
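

# --- Example usage (sketch) ---------------------------------------------------
# A minimal, hedged sketch of how this pipeline might be driven; it is not part
# of the pipeline above. The checkpoint id, full-precision loading, prompts,
# step count, and output filename below are assumptions — adjust them to your
# setup (e.g. fp16 weights require matching dtypes across all modules).
if __name__ == "__main__":
    from diffusers import AutoencoderKL, UNet2DConditionModel
    from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

    model_id = "stabilityai/stable-diffusion-xl-base-1.0"  # assumed checkpoint layout

    # Load the individual SDXL components this pipeline expects.
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_id, subfolder="text_encoder_2")
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    tokenizer_2 = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")

    pipe = SuperDiffSDXLPipeline(
        unet=unet,
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
    )

    # Superimpose two prompts; returns a uint8 array of shape (batch, H, W, 3).
    # SDXL base was trained at 1024x1024, so that resolution is used here.
    images = pipe(
        "a photo of a lion",
        "a photo of a tiger",
        seed=0,
        num_inference_steps=200,
        height=1024,
        width=1024,
        guidance_scale=7.5,
    )

    from PIL import Image

    Image.fromarray(images[0]).save("superdiff_sdxl_sample.png")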