import math
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import randn_tensor

from model.dit import DiT_Llama

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

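# This module defines PiTPipelineOutput, a thin output container, and PiTPipeline,
# which wraps a DiT_Llama prior and samples image embeddings by Euler-integrating
# the learned rectified-flow velocity field, with optional classifier-free guidance.
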
@dataclass
class PiTPipelineOutput(BaseOutput):
    image_embeds: torch.Tensor

class PiTPipeline(DiffusionPipeline):
    def __init__(self, prior: DiT_Llama):
        super().__init__()
        self.register_modules(prior=prior)

    def prepare_latents(self, shape, dtype, device, generator, latents):
        # Draw fresh Gaussian noise unless the caller supplied latents of the right shape.
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)
        return latents

    def __call__(
        self,
        cond_sequence: torch.FloatTensor,
        negative_cond_sequence: torch.FloatTensor,
        num_images_per_prompt: int = 1,
        num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        init_latents: Optional[torch.FloatTensor] = None,
        strength: Optional[float] = None,
        guidance_scale: float = 1.0,
        output_type: Optional[str] = "pt",  # "pt" or "np"
        return_dict: bool = True,
    ):
        do_classifier_free_guidance = guidance_scale > 1.0
        device = self._execution_device

        batch_size = cond_sequence.shape[0] * num_images_per_prompt
        embedding_dim = self.prior.config.embedding_dim
        latents = self.prepare_latents(
            (batch_size, 16, embedding_dim),
            self.prior.dtype,
            device,
            generator,
            latents,
        )

        if init_latents is not None:
            if strength is None:
                raise ValueError("`strength` must be set when `init_latents` is provided")
            # Image-to-image style start: blend fresh noise with the supplied latents.
            init_latents = init_latents.to(latents.device)
            latents = strength * latents + (1 - strength) * init_latents

        # Rectified flow sampling: Euler-integrate the learned velocity field from
        # t = 1 (pure noise) down to t = 0 (data) with a fixed step size dt.
        dt = 1.0 / num_inference_steps
        dt = torch.tensor([dt] * batch_size).to(latents.device).view(batch_size, *([1] * (latents.ndim - 1)))
        # When resuming from init_latents, start partway through the trajectory.
        start_inference_step = (
            math.ceil(num_inference_steps * strength) if strength is not None else num_inference_steps
        )
        for i in range(start_inference_step, 0, -1):
            t = i / num_inference_steps
            t = torch.tensor([t] * batch_size).to(latents.device)
            # Conditional velocity prediction from the DiT prior.
            vc = self.prior(latents, t, cond_sequence)
            if do_classifier_free_guidance:
                # Classifier-free guidance: extrapolate from the unconditional
                # velocity vu towards the conditional velocity vc.
                vu = self.prior(latents, t, negative_cond_sequence)
                vc = vu + guidance_scale * (vc - vu)
            latents = latents - dt * vc

        image_embeddings = latents

        if output_type not in ["pt", "np"]:
            raise ValueError(f"Only the output types `pt` and `np` are supported, not output_type={output_type}")
        if output_type == "np":
            image_embeddings = image_embeddings.cpu().numpy()

        if not return_dict:
            return image_embeddings

        return PiTPipelineOutput(image_embeds=image_embeddings)
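

# A minimal usage sketch (assumptions: DiT_Llama exposes the standard diffusers
# `from_pretrained` loader, and the conditioning tensors normally come from an
# upstream encoder; the checkpoint path, sequence length, and shapes below are
# illustrative, not part of this module).
if __name__ == "__main__":
    prior = DiT_Llama.from_pretrained("path/to/prior")  # hypothetical checkpoint path
    pipe = PiTPipeline(prior=prior).to("cuda")

    embedding_dim = prior.config.embedding_dim
    # Stand-in conditioning: random conditional sequence, zeroed negative sequence.
    cond_sequence = torch.randn(1, 77, embedding_dim, device="cuda", dtype=prior.dtype)
    negative_cond_sequence = torch.zeros_like(cond_sequence)

    out = pipe(
        cond_sequence=cond_sequence,
        negative_cond_sequence=negative_cond_sequence,
        num_inference_steps=25,
        guidance_scale=3.0,
        generator=torch.Generator("cuda").manual_seed(0),
    )
    print(out.image_embeds.shape)  # expected: (1, 16, embedding_dim)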