import math
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import randn_tensor

from model.dit import DiT_Llama

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class PiTPipelineOutput(BaseOutput):
    """Output of `PiTPipeline`: the generated image embeddings."""

    image_embeds: torch.Tensor


class PiTPipeline(DiffusionPipeline):
    """Rectified-flow prior that maps a conditioning sequence to image embeddings."""

    def __init__(self, prior: DiT_Llama):
        super().__init__()
        self.register_modules(
            prior=prior,
        )

    def prepare_latents(self, shape, dtype, device, generator, latents):
        # Sample Gaussian noise unless the caller supplied latents of the expected shape.
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)
        return latents

    @torch.no_grad()
    def __call__(
        self,
        cond_sequence: torch.FloatTensor,
        negative_cond_sequence: torch.FloatTensor,
        num_images_per_prompt: int = 1,
        num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        init_latents: Optional[torch.FloatTensor] = None,
        strength: Optional[float] = None,
        guidance_scale: float = 1.0,
        output_type: Optional[str] = "pt",  # "pt" or "np"
        return_dict: bool = True,
    ):
        do_classifier_free_guidance = guidance_scale > 1.0
        device = self._execution_device

        batch_size = cond_sequence.shape[0] * num_images_per_prompt
        if num_images_per_prompt > 1:
            # Expand the conditioning sequences so their batch dimension matches the latents.
            cond_sequence = cond_sequence.repeat_interleave(num_images_per_prompt, dim=0)
            negative_cond_sequence = negative_cond_sequence.repeat_interleave(num_images_per_prompt, dim=0)

        embedding_dim = self.prior.config.embedding_dim
        latents = self.prepare_latents(
            (batch_size, 16, embedding_dim),
            self.prior.dtype,
            device,
            generator,
            latents,
        )

        if init_latents is not None:
            # Blend fresh noise with the provided latents; `strength` is the fraction of
            # noise kept (strength=1.0 ignores init_latents, strength=0.0 keeps them as-is).
            init_latents = init_latents.to(latents.device)
            latents = strength * latents + (1 - strength) * init_latents

        # Rectified flow: integrate the learned velocity field from t=1 (noise) toward t=0.
        dt = 1.0 / num_inference_steps
        dt = torch.tensor([dt] * batch_size).to(latents.device).view([batch_size, *([1] * len(latents.shape[1:]))])

        # When starting from init_latents, only run the last `strength` fraction of the steps.
        start_inference_step = (
            math.ceil(num_inference_steps * strength) if strength is not None else num_inference_steps
        )

        for i in range(start_inference_step, 0, -1):
            t = i / num_inference_steps
            t = torch.tensor([t] * batch_size).to(latents.device)
            # Velocity prediction conditioned on the prompt sequence.
            vc = self.prior(latents, t, cond_sequence)
            if do_classifier_free_guidance:
                # Classifier-free guidance: extrapolate away from the negative prediction.
                vu = self.prior(latents, t, negative_cond_sequence)
                vc = vu + guidance_scale * (vc - vu)
            # Euler step backwards in time.
            latents = latents - dt * vc

        image_embeddings = latents

        if output_type not in ["pt", "np"]:
            raise ValueError(f"Only output types `pt` and `np` are supported, got output_type={output_type}")
        if output_type == "np":
            image_embeddings = image_embeddings.cpu().numpy()

        if not return_dict:
            return image_embeddings

        return PiTPipelineOutput(image_embeds=image_embeddings)
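

# Hypothetical usage sketch (not part of the original file). It assumes a trained
# DiT_Llama prior `prior` and conditioning tensors `cond` / `neg_cond` of shape
# (batch, seq_len, hidden_dim) are already available; all names below are illustrative.
#
#     pipe = PiTPipeline(prior=prior).to("cuda")
#     out = pipe(
#         cond_sequence=cond,
#         negative_cond_sequence=neg_cond,
#         num_inference_steps=25,
#         guidance_scale=3.0,
#     )
#     image_embeds = out.image_embeds  # (batch, 16, embedding_dim)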