lehduong committed on
Commit d5e50c0 · verified · 1 Parent(s): acf05ba

Delete diffusion/pipelines/onediffusion.py with huggingface_hub

Files changed (1)
  1. diffusion/pipelines/onediffusion.py +0 -1080
diffusion/pipelines/onediffusion.py DELETED
@@ -1,1080 +0,0 @@
1
- import einops
2
- import inspect
3
- import torch
4
- import numpy as np
5
- import PIL
6
- import os
7
-
8
- from dataclasses import dataclass
9
- from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
10
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
11
- from diffusers.utils import (
12
- CONFIG_NAME,
13
- DEPRECATED_REVISION_ARGS,
14
- BaseOutput,
15
- PushToHubMixin,
16
- deprecate,
17
- is_accelerate_available,
18
- is_accelerate_version,
19
- is_torch_npu_available,
20
- is_torch_version,
21
- logging,
22
- numpy_to_pil,
23
- replace_example_docstring,
24
- )
25
- from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
26
- from diffusers.utils.torch_utils import randn_tensor
27
- from diffusers.utils import BaseOutput
28
- # from diffusers.image_processor import VaeImageProcessor
29
- from transformers import T5EncoderModel, T5Tokenizer
30
- from typing import Any, Callable, Dict, List, Optional, Union
31
- from PIL import Image
32
-
33
- from onediffusion.models.denoiser.nextdit import NextDiT
34
- from onediffusion.dataset.utils import *
35
- from onediffusion.dataset.multitask.multiview import calculate_rays
36
- from onediffusion.diffusion.pipelines.image_processor import VaeImageProcessorOneDiffuser
37
-
38
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
-
40
- SUPPORTED_DEVICE_MAP = ["balanced"]
41
-
42
- EXAMPLE_DOC_STRING = """
43
- Examples:
44
- ```py
45
- >>> import torch
46
- >>> from one_diffusion import OneDiffusionPipeline
47
-
48
- >>> pipe = OneDiffusionPipeline.from_pretrained("path_to_one_diffuser_model")
49
- >>> pipe = pipe.to("cuda")
50
-
51
- >>> prompt = "A beautiful sunset over the ocean"
52
- >>> image = pipe(prompt).images[0]
53
- >>> image.save("beautiful_sunset.png")
54
- ```
55
- """
56
-
57
- def create_c2w_matrix(azimuth_deg, elevation_deg, distance=1.0, target=np.array([0, 0, 0])):
58
- """
59
- Create a Camera-to-World (C2W) matrix from azimuth and elevation angles.
60
-
61
- Parameters:
62
- - azimuth_deg: Azimuth angle in degrees.
63
- - elevation_deg: Elevation angle in degrees.
64
- - distance: Distance from the target point.
65
- - target: The point the camera is looking at in world coordinates.
66
-
67
- Returns:
68
- - C2W: A 4x4 NumPy array representing the Camera-to-World transformation matrix.
69
- """
70
- # Convert angles from degrees to radians
71
- azimuth = np.deg2rad(azimuth_deg)
72
- elevation = np.deg2rad(elevation_deg)
73
-
74
- # Spherical to Cartesian conversion for camera position
75
- x = distance * np.cos(elevation) * np.cos(azimuth)
76
- y = distance * np.cos(elevation) * np.sin(azimuth)
77
- z = distance * np.sin(elevation)
78
- camera_position = np.array([x, y, z])
79
-
80
- # Define the forward vector (from camera to target)
81
- target = 2*camera_position - target  # mirror the target through the camera so `forward` (+Z) points away from the scene; the camera then looks down -Z at the original target
82
- forward = target - camera_position
83
- forward /= np.linalg.norm(forward)
84
-
85
- # Define the world up vector
86
- world_up = np.array([0, 0, 1])
87
-
88
- # Compute the right vector
89
- right = np.cross(world_up, forward)
90
- if np.linalg.norm(right) < 1e-6:
91
- # Handle the singularity when forward is parallel to world_up
92
- world_up = np.array([0, 1, 0])
93
- right = np.cross(world_up, forward)
94
- right /= np.linalg.norm(right)
95
-
96
- # Recompute the orthogonal up vector
97
- up = np.cross(forward, right)
98
-
99
- # Construct the rotation matrix
100
- rotation = np.vstack([right, up, forward]).T # 3x3
101
-
102
- # Construct the full C2W matrix
103
- C2W = np.eye(4)
104
- C2W[:3, :3] = rotation
105
- C2W[:3, 3] = camera_position
106
-
107
- return C2W
108
-
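For illustration, a minimal sketch of how this helper can be combined into the stacked camera poses that `img2img` builds further down when `multiview_c2ws` is not supplied; the azimuth/elevation/distance values simply mirror the `img2img` defaults:

```py
import torch

# Four cameras on a ring around the origin, mirroring the img2img defaults
# (azimuths 0/30/60/90 degrees, elevation 0, distance 1.7).
azimuths = [0, 30, 60, 90]
elevations = [0, 0, 0, 0]
distances = [1.7] * len(azimuths)

c2ws = torch.stack([
    torch.tensor(create_c2w_matrix(az, el, dist))
    for az, el, dist in zip(azimuths, elevations, distances)
]).float()

print(c2ws.shape)  # torch.Size([4, 4, 4]) -- one 4x4 camera-to-world matrix per view
```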
109
- @dataclass
110
- class OneDiffusionPipelineOutput(BaseOutput):
111
- """
112
- Output class for the OneDiffusion pipeline.
113
-
114
- Args:
115
- images (`List[PIL.Image.Image]` or `np.ndarray`)
116
- List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
117
- num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
118
- """
119
-
120
- images: Union[List[Image.Image], np.ndarray]
121
- latents: Optional[torch.Tensor] = None
122
-
123
-
124
- def retrieve_latents(
125
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
126
- ):
127
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
128
- return encoder_output.latent_dist.sample(generator)
129
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
130
- return encoder_output.latent_dist.mode()
131
- elif hasattr(encoder_output, "latents"):
132
- return encoder_output.latents
133
- else:
134
- raise AttributeError("Could not access latents of provided encoder_output")
135
-
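A small usage sketch, assuming a VAE loaded from the same checkpoint layout that `from_pretrained` below expects (the path is a placeholder); `AutoencoderKL.encode` returns an object carrying a `latent_dist`, which is the first branch handled here:

```py
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("path_to_one_diffuser_model", subfolder="vae")  # placeholder path
image = torch.randn(1, 3, 512, 512)  # stands in for a preprocessed image scaled to [-1, 1]

with torch.no_grad():
    latents = retrieve_latents(vae.encode(image), generator=torch.Generator().manual_seed(0))

# For an 8x-downsampling VAE the spatial size is 64x64; the channel count depends on the checkpoint.
print(latents.shape)
```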
136
-
137
- def calculate_shift(
138
- image_seq_len,
139
- base_seq_len: int = 256,
140
- max_seq_len: int = 4096,
141
- base_shift: float = 0.5,
142
- max_shift: float = 1.16,
143
- # max_clip: float = 1.5,
144
- ):
145
- m = (max_shift - base_shift) / (max_seq_len - base_seq_len) # (1.16 - 0.5) / (4096 - 256) ≈ 0.000171875 with the defaults
146
- b = base_shift - m * base_seq_len # 0.5 - 0.000171875 * 256 ≈ 0.456 with the defaults
147
- mu = image_seq_len * m + b
148
- # mu = min(mu, max_clip)
149
- return mu
150
-
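As a worked example with the default constants above (and assuming the transformer's 2x2 patching used in the sequence-length computation further down): a 512x512 image has a 64x64 latent, i.e. 32 * 32 = 1024 tokens, while a 1024x1024 image reaches the 4096-token endpoint, so `mu` interpolates linearly between `base_shift` and `max_shift`:

```py
mu_512 = calculate_shift(1024)   # 1024 * 0.000171875 + 0.456 ≈ 0.632
mu_1024 = calculate_shift(4096)  # the max_seq_len endpoint -> max_shift = 1.16
print(round(mu_512, 3), round(mu_1024, 3))
```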
151
-
152
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
153
- def retrieve_timesteps(
154
- scheduler,
155
- num_inference_steps: Optional[int] = None,
156
- device: Optional[Union[str, torch.device]] = None,
157
- timesteps: Optional[List[int]] = None,
158
- sigmas: Optional[List[float]] = None,
159
- **kwargs,
160
- ):
161
- """
162
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
163
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
164
-
165
- Args:
166
- scheduler (`SchedulerMixin`):
167
- The scheduler to get timesteps from.
168
- num_inference_steps (`int`):
169
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
170
- must be `None`.
171
- device (`str` or `torch.device`, *optional*):
172
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
173
- timesteps (`List[int]`, *optional*):
174
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
175
- `num_inference_steps` and `sigmas` must be `None`.
176
- sigmas (`List[float]`, *optional*):
177
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
178
- `num_inference_steps` and `timesteps` must be `None`.
179
-
180
- Returns:
181
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
182
- second element is the number of inference steps.
183
- """
184
- if timesteps is not None and sigmas is not None:
185
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
186
- if timesteps is not None:
187
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
188
- if not accepts_timesteps:
189
- raise ValueError(
190
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
191
- f" timestep schedules. Please check whether you are using the correct scheduler."
192
- )
193
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
194
- timesteps = scheduler.timesteps
195
- num_inference_steps = len(timesteps)
196
- elif sigmas is not None:
197
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
198
- if not accept_sigmas:
199
- raise ValueError(
200
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
201
- f" sigmas schedules. Please check whether you are using the correct scheduler."
202
- )
203
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
204
- timesteps = scheduler.timesteps
205
- num_inference_steps = len(timesteps)
206
- else:
207
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
208
- timesteps = scheduler.timesteps
209
- return timesteps, num_inference_steps
210
-
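A minimal sketch of how the pipeline drives this helper further down: custom `sigmas` plus the `mu` from `calculate_shift`, which `FlowMatchEulerDiscreteScheduler.set_timesteps` accepts as a keyword when dynamic shifting is enabled. The scheduler is constructed directly here only for illustration; the pipeline loads its configuration from the checkpoint:

```py
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=True)

num_inference_steps = 50
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)

timesteps, num_inference_steps = retrieve_timesteps(
    scheduler,
    num_inference_steps,
    device="cpu",
    sigmas=sigmas,
    mu=0.632,  # e.g. the 512x512 value from calculate_shift
)
print(len(timesteps), num_inference_steps)
```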
211
-
212
-
213
- class OneDiffusionPipeline(DiffusionPipeline):
214
- r"""
215
- Pipeline for text-to-image generation using OneDiffuser.
216
-
217
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
218
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
219
-
220
- Args:
221
- transformer ([`NextDiT`]):
222
- Conditional transformer (NextDiT) architecture to denoise the encoded image latents.
223
- vae ([`AutoencoderKL`]):
224
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
225
- text_encoder ([`T5EncoderModel`]):
226
- Frozen text-encoder. OneDiffuser uses the T5 model as text encoder.
227
- tokenizer (`T5Tokenizer`):
228
- Tokenizer of class T5Tokenizer.
229
- scheduler ([`FlowMatchEulerDiscreteScheduler`]):
230
- A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
231
- """
232
-
233
- def __init__(
234
- self,
235
- transformer: NextDiT,
236
- vae: AutoencoderKL,
237
- text_encoder: T5EncoderModel,
238
- tokenizer: T5Tokenizer,
239
- scheduler: FlowMatchEulerDiscreteScheduler,
240
- ):
241
- super().__init__()
242
- self.register_modules(
243
- transformer=transformer,
244
- vae=vae,
245
- text_encoder=text_encoder,
246
- tokenizer=tokenizer,
247
- scheduler=scheduler,
248
- )
249
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
250
- self.image_processor = VaeImageProcessorOneDiffuser(vae_scale_factor=self.vae_scale_factor)
251
-
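For reference, with a standard four-level `AutoencoderKL` config (four entries in `block_out_channels`, an assumption about the checkpoint) this works out to the 8x downsampling that the hardcoded `* 8` defaults in `__call__` rely on:

```py
vae_scale_factor = 2 ** (4 - 1)   # == 8 for a 4-level VAE config
height = width = 512
print(height // vae_scale_factor, width // vae_scale_factor)  # 64 64, the latent shape used in prepare_latents
```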
252
- def enable_vae_slicing(self):
253
- self.vae.enable_slicing()
254
-
255
- def disable_vae_slicing(self):
256
- self.vae.disable_slicing()
257
-
258
- def enable_sequential_cpu_offload(self, gpu_id=0):
259
- if is_accelerate_available():
260
- from accelerate import cpu_offload
261
- else:
262
- raise ImportError("Please install accelerate via `pip install accelerate`")
263
-
264
- device = torch.device(f"cuda:{gpu_id}")
265
-
266
- for cpu_offloaded_model in [self.transformer, self.text_encoder, self.vae]:
267
- if cpu_offloaded_model is not None:
268
- cpu_offload(cpu_offloaded_model, device)
269
-
270
- @property
271
- def _execution_device(self):
272
- if self.device != torch.device("meta") or not hasattr(self.transformer, "_hf_hook"):
273
- return self.device
274
- for module in self.transformer.modules():
275
- if (
276
- hasattr(module, "_hf_hook")
277
- and hasattr(module._hf_hook, "execution_device")
278
- and module._hf_hook.execution_device is not None
279
- ):
280
- return torch.device(module._hf_hook.execution_device)
281
- return self.device
282
-
283
- def encode_prompt(
284
- self,
285
- prompt,
286
- device,
287
- num_images_per_prompt,
288
- do_classifier_free_guidance,
289
- negative_prompt=None,
290
- max_length=300,
291
- ):
292
- batch_size = len(prompt) if isinstance(prompt, list) else 1
293
-
294
- text_inputs = self.tokenizer(
295
- prompt,
296
- padding="max_length",
297
- max_length=max_length,
298
- truncation=True,
299
- add_special_tokens=True,
300
- return_tensors="pt",
301
- )
302
- text_input_ids = text_inputs.input_ids
303
- attention_mask = text_inputs.attention_mask
304
-
305
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
306
-
307
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
308
- removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
309
- logger.warning(
310
- "The following part of your input was truncated because CLIP can only handle sequences up to"
311
- f" {max_length} tokens: {removed_text}"
312
- )
313
-
314
- text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device))
315
- prompt_embeds = text_encoder_output[0].to(torch.float32)
316
-
317
- # duplicate text embeddings for each generation per prompt, using mps friendly method
318
- bs_embed, seq_len, _ = prompt_embeds.shape
319
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
320
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
321
-
322
- # duplicate attention mask for each generation per prompt
323
- attention_mask = attention_mask.repeat(1, num_images_per_prompt)
324
- attention_mask = attention_mask.view(bs_embed * num_images_per_prompt, -1)
325
-
326
- # get unconditional embeddings for classifier free guidance
327
- if do_classifier_free_guidance:
328
- uncond_tokens: List[str]
329
- if negative_prompt is None:
330
- uncond_tokens = [""] * batch_size
331
- elif type(prompt) is not type(negative_prompt):
332
- raise TypeError(
333
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
334
- f" {type(prompt)}."
335
- )
336
- elif isinstance(negative_prompt, str):
337
- uncond_tokens = [negative_prompt]
338
- elif batch_size != len(negative_prompt):
339
- raise ValueError(
340
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
341
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
342
- " the batch size of `prompt`."
343
- )
344
- else:
345
- uncond_tokens = negative_prompt
346
-
347
- max_length = text_input_ids.shape[-1]
348
- uncond_input = self.tokenizer(
349
- uncond_tokens,
350
- padding="max_length",
351
- max_length=max_length,
352
- truncation=True,
353
- return_tensors="pt",
354
- )
355
-
356
- uncond_encoder_output = self.text_encoder(uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device))
357
- negative_prompt_embeds = uncond_encoder_output[0].to(torch.float32)
358
-
359
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
360
- seq_len = negative_prompt_embeds.shape[1]
361
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
362
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
363
-
364
- # duplicate unconditional attention mask for each generation per prompt
365
- uncond_attention_mask = uncond_input.attention_mask.repeat(1, num_images_per_prompt)
366
- uncond_attention_mask = uncond_attention_mask.view(batch_size * num_images_per_prompt, -1)
367
-
368
- # For classifier free guidance, we need to do two forward passes.
369
- # Here we concatenate the unconditional and text embeddings into a single batch
370
- # to avoid doing two forward passes
371
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
372
- attention_mask = torch.cat([uncond_attention_mask, attention_mask])
373
-
374
- return prompt_embeds.to(device), attention_mask.to(device)
375
-
376
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
377
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
378
- if isinstance(generator, list) and len(generator) != batch_size:
379
- raise ValueError(
380
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
381
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
382
- )
383
-
384
- if latents is None:
385
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
386
- else:
387
- latents = latents.to(device)
388
-
389
- # scale the initial noise by the standard deviation required by the scheduler
390
- latents = latents * self.scheduler.init_noise_sigma
391
- return latents
392
-
393
- @torch.no_grad()
394
- @replace_example_docstring(EXAMPLE_DOC_STRING)
395
- def __call__(
396
- self,
397
- prompt: Union[str, List[str]] = None,
398
- height: Optional[int] = None,
399
- width: Optional[int] = None,
400
- num_inference_steps: int = 50,
401
- guidance_scale: float = 5.0,
402
- negative_prompt: Optional[Union[str, List[str]]] = None,
403
- num_images_per_prompt: Optional[int] = 1,
404
- eta: float = 0.0,
405
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
406
- latents: Optional[torch.FloatTensor] = None,
407
- output_type: Optional[str] = "pil",
408
- return_dict: bool = True,
409
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
410
- callback_steps: int = 1,
411
- forward_kwargs: Optional[Dict[str, Any]] = {},
412
- **kwargs,
413
- ):
414
- r"""
415
- Function invoked when calling the pipeline for generation.
416
-
417
- Args:
418
- prompt (`str` or `List[str]`, *optional*):
419
- The prompt or prompts to guide the image generation. A string or a list of strings is required.
420
- height (`int`, *optional*, defaults to `self.transformer.config.input_size[-2] * 8`):
421
- The height in pixels of the generated image.
422
- width (`int`, *optional*, defaults to `self.transformer.config.input_size[-1] * 8`):
423
- The width in pixels of the generated image.
424
- num_inference_steps (`int`, *optional*, defaults to 50):
425
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
426
- expense of slower inference.
427
- guidance_scale (`float`, *optional*, defaults to 5.0):
428
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
429
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
430
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
431
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
432
- usually at the expense of lower image quality.
433
- negative_prompt (`str` or `List[str]`, *optional*):
434
- The prompt or prompts not to guide the image generation. Ignored when not using guidance
435
- (i.e., ignored if `guidance_scale` is less than `1`). If not provided, an empty prompt ("") is
436
- used for the unconditional branch.
437
- num_images_per_prompt (`int`, *optional*, defaults to 1):
438
- The number of images to generate per prompt.
439
- eta (`float`, *optional*, defaults to 0.0):
440
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
441
- [`schedulers.DDIMScheduler`], will be ignored for others.
442
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
443
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
444
- to make generation deterministic.
445
- latents (`torch.FloatTensor`, *optional*):
446
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
447
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
448
- tensor will be generated by sampling using the supplied random `generator`.
449
- output_type (`str`, *optional*, defaults to `"pil"`):
450
- The output format of the generate image. Choose between
451
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
452
- return_dict (`bool`, *optional*, defaults to `True`):
453
- Whether or not to return a [`OneDiffusionPipelineOutput`] instead of a
454
- plain tuple.
455
- callback (`Callable`, *optional*):
456
- A function that will be called every `callback_steps` steps during inference. The function will be
457
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
458
- callback_steps (`int`, *optional*, defaults to 1):
459
- The frequency at which the `callback` function will be called. If not specified, the callback will be
460
- called at every step.
461
-
462
- Examples:
463
-
464
- Returns:
465
- [`OneDiffusionPipelineOutput`] or `tuple`:
466
- [`OneDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
467
- When returning a tuple, the first element is a list with the generated images and the second element is
468
- `None`, kept for interface compatibility with pipelines that also return safety-checker flags; this
469
- pipeline has no safety checker.
470
- """
471
- height = height or self.transformer.config.input_size[-2] * 8 # TODO: Hardcoded downscale factor of vae
472
- width = width or self.transformer.config.input_size[-1] * 8
473
-
474
- # check inputs. Raise error if not correct
475
- self.check_inputs(prompt, height, width, callback_steps)
476
-
477
- # define call parameters
478
- batch_size = 1 if isinstance(prompt, str) else len(prompt)
479
- device = self._execution_device
480
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
481
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf
482
- do_classifier_free_guidance = guidance_scale > 1.0
483
-
484
- encoder_hidden_states, encoder_attention_mask = self.encode_prompt(
485
- prompt,
486
- device,
487
- num_images_per_prompt,
488
- do_classifier_free_guidance,
489
- negative_prompt,
490
- )
491
-
492
- # set timesteps
493
- # # self.scheduler.set_timesteps(num_inference_steps, device=device)
494
- # timesteps = self.scheduler.timesteps
495
- timesteps = None
496
-
497
- # prepare latent variables
498
- num_channels_latents = self.transformer.config.in_channels
499
- latents = self.prepare_latents(
500
- batch_size * num_images_per_prompt,
501
- num_channels_latents,
502
- height,
503
- width,
504
- self.dtype,
505
- device,
506
- generator,
507
- latents,
508
- )
509
-
510
- # prepare extra step kwargs
511
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
512
-
513
- # 5. Prepare timesteps
514
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
515
- image_seq_len = latents.shape[-1] * latents.shape[-2] / self.transformer.config.patch_size[-1] / self.transformer.config.patch_size[-2]
516
- mu = calculate_shift(
517
- image_seq_len,
518
- self.scheduler.config.base_image_seq_len,
519
- self.scheduler.config.max_image_seq_len,
520
- self.scheduler.config.base_shift,
521
- self.scheduler.config.max_shift,
522
- )
523
- timesteps, num_inference_steps = retrieve_timesteps(
524
- self.scheduler,
525
- num_inference_steps,
526
- device,
527
- timesteps,
528
- sigmas,
529
- mu=mu,
530
- )
531
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
532
- self._num_timesteps = len(timesteps)
533
-
534
- # denoising loop
535
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
536
- with self.progress_bar(total=num_inference_steps) as progress_bar:
537
- for i, t in enumerate(timesteps):
538
- # expand the latents if we are doing classifier free guidance
539
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
540
- # latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
541
-
542
- # predict the noise residual
543
- noise_pred = self.transformer(
544
- samples=latent_model_input.to(self.dtype),
545
- timesteps=torch.tensor([t] * latent_model_input.shape[0], device=device),
546
- encoder_hidden_states=encoder_hidden_states.to(self.dtype),
547
- encoder_attention_mask=encoder_attention_mask,
548
- **forward_kwargs
549
- )
550
-
551
- # perform guidance
552
- if do_classifier_free_guidance:
553
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
554
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
555
-
556
- # compute the previous noisy sample x_t -> x_t-1
557
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
558
-
559
- # call the callback, if provided
560
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
561
- progress_bar.update()
562
- if callback is not None and i % callback_steps == 0:
563
- callback(i, t, latents)
564
-
565
- # scale and decode the image latents with vae
566
- latents = 1 / self.vae.config.scaling_factor * latents
567
- if latents.ndim == 5:
568
- latents = latents.squeeze(1)
569
- image = self.vae.decode(latents.to(self.vae.dtype)).sample
570
-
571
- image = (image / 2 + 0.5).clamp(0, 1)
572
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
573
-
574
- if output_type == "pil":
575
- image = self.numpy_to_pil(image)
576
-
577
- if not return_dict:
578
- return (image, None)
579
-
580
- return OneDiffusionPipelineOutput(images=image)
581
-
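Putting `__call__` together end to end, a minimal sketch that mirrors the docstring example at the top of this file (the checkpoint path is a placeholder; `check_inputs` requires `height` and `width` divisible by 16, and classifier-free guidance is active whenever `guidance_scale > 1`):

```py
import torch

pipe = OneDiffusionPipeline.from_pretrained("path_to_one_diffuser_model")  # placeholder path
pipe = pipe.to("cuda")

result = pipe(
    prompt="A beautiful sunset over the ocean",
    height=512,
    width=512,
    num_inference_steps=50,
    guidance_scale=5.0,
    generator=torch.Generator(device="cuda").manual_seed(0),
)
result.images[0].save("beautiful_sunset.png")
```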
582
- @torch.no_grad()
583
- def img2img(
584
- self,
585
- prompt: Union[str, List[str]] = None,
586
- image: Union[PIL.Image.Image, List[PIL.Image.Image]] = None,
587
- height: Optional[int] = None,
588
- width: Optional[int] = None,
589
- num_inference_steps: int = 50,
590
- guidance_scale: float = 5.0,
591
- denoise_mask: Optional[List[int]] = [1, 0],
592
- negative_prompt: Optional[Union[str, List[str]]] = None,
593
- num_images_per_prompt: Optional[int] = 1,
594
- eta: float = 0.0,
595
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
596
- latents: Optional[torch.FloatTensor] = None,
597
- output_type: Optional[str] = "pil",
598
- return_dict: bool = True,
599
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
600
- callback_steps: int = 1,
601
- do_crop: bool = True,
602
- is_multiview: bool = False,
603
- multiview_azimuths: Optional[List[int]] = [0, 30, 60, 90],
604
- multiview_elevations: Optional[List[int]] = [0, 0, 0, 0],
605
- multiview_distances: float = 1.7,
606
- multiview_c2ws: Optional[List[torch.Tensor]] = None,
607
- multiview_intrinsics: Optional[torch.Tensor] = None,
608
- multiview_focal_length: float = 1.3887,
609
- forward_kwargs: Optional[Dict[str, Any]] = {},
610
- noise_scale: float = 1.0,
611
- **kwargs,
612
- ):
613
- # Convert single image to list for consistent handling
614
- if isinstance(image, PIL.Image.Image):
615
- image = [image]
616
-
617
- if height is None or width is None:
618
- closest_ar = get_closest_ratio(height=image[0].size[1], width=image[0].size[0], ratios=ASPECT_RATIO_512)
619
- height, width = int(closest_ar[0][0]), int(closest_ar[0][1])
620
-
621
- if not isinstance(multiview_distances, list) and not isinstance(multiview_distances, tuple):
622
- multiview_distances = [multiview_distances] * len(multiview_azimuths)
623
-
624
- # height = height or self.transformer.config.input_size[-2] * 8 # TODO: Hardcoded downscale factor of vae
625
- # width = width or self.transformer.config.input_size[-1] * 8
626
-
627
- # 1. check inputs. Raise error if not correct
628
- self.check_inputs(prompt, height, width, callback_steps)
629
-
630
- # Additional input validation for image list
631
- if not all(isinstance(img, PIL.Image.Image) for img in image):
632
- raise ValueError("All elements in image list must be PIL.Image objects")
633
-
634
- # 2. define call parameters
635
- batch_size = 1 if isinstance(prompt, str) else len(prompt)
636
- device = self._execution_device
637
- do_classifier_free_guidance = guidance_scale > 1.0
638
-
639
- # 3. Encode input prompt
640
- encoder_hidden_states, encoder_attention_mask = self.encode_prompt(
641
- prompt,
642
- device,
643
- num_images_per_prompt,
644
- do_classifier_free_guidance,
645
- negative_prompt,
646
- )
647
-
648
- # 4. Preprocess all images
649
- if image is not None and len(image) > 0:
650
- processed_image = self.image_processor.preprocess(image, height=height, width=width, do_crop=do_crop)
651
- else:
652
- processed_image = None
653
-
654
- # # Stack processed images along the sequence dimension
655
- # if len(processed_images) > 1:
656
- # processed_image = torch.cat(processed_images, dim=0)
657
- # else:
658
- # processed_image = processed_images[0]
659
-
660
- timesteps = None
661
-
662
- # 6. prepare latent variables
663
- num_channels_latents = self.transformer.config.in_channels
664
- if processed_image is not None:
665
- cond_latents = self.prepare_latents(
666
- batch_size * num_images_per_prompt,
667
- num_channels_latents,
668
- height,
669
- width,
670
- self.dtype,
671
- device,
672
- generator,
673
- latents,
674
- image=processed_image,
675
- )
676
- else:
677
- cond_latents = None
678
-
679
- # 7. prepare extra step kwargs
680
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
681
- denoise_mask = torch.tensor(denoise_mask, device=device)
682
- denoise_indices = torch.where(denoise_mask == 1)[0]
683
- cond_indices = torch.where(denoise_mask == 0)[0]
684
- seq_length = denoise_mask.shape[0]
685
-
686
- latents = self.prepare_init_latents(
687
- batch_size * num_images_per_prompt,
688
- seq_length,
689
- num_channels_latents,
690
- height,
691
- width,
692
- self.dtype,
693
- device,
694
- generator,
695
- )
696
-
697
- # 5. Prepare timesteps
698
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
699
- # image_seq_len = latents.shape[1] * latents.shape[-1] * latents.shape[-2] / self.transformer.config.patch_size[-1] / self.transformer.config.patch_size[-2]
700
- image_seq_len = noise_scale * sum(denoise_mask) * latents.shape[-1] * latents.shape[-2] / self.transformer.config.patch_size[-1] / self.transformer.config.patch_size[-2]
701
- # image_seq_len = 256
702
- mu = calculate_shift(
703
- image_seq_len,
704
- self.scheduler.config.base_image_seq_len,
705
- self.scheduler.config.max_image_seq_len,
706
- self.scheduler.config.base_shift,
707
- self.scheduler.config.max_shift,
708
- )
709
- timesteps, num_inference_steps = retrieve_timesteps(
710
- self.scheduler,
711
- num_inference_steps,
712
- device,
713
- timesteps,
714
- sigmas,
715
- mu=mu,
716
- )
717
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
718
- self._num_timesteps = len(timesteps)
719
-
720
- if is_multiview:
721
- cond_indices_images = [index // 2 for index in cond_indices if index % 2 == 0]
722
- cond_indices_rays = [index // 2 for index in cond_indices if index % 2 == 1]
723
-
724
- multiview_elevations = [element for element in multiview_elevations if element is not None]
725
- multiview_azimuths = [element for element in multiview_azimuths if element is not None]
726
- multiview_distances = [element for element in multiview_distances if element is not None]
727
-
728
- if multiview_c2ws is None:
729
- multiview_c2ws = [
730
- torch.tensor(create_c2w_matrix(azimuth, elevation, distance)) for azimuth, elevation, distance in zip(multiview_azimuths, multiview_elevations, multiview_distances)
731
- ]
732
- c2ws = torch.stack(multiview_c2ws).float()
733
- else:
734
- c2ws = torch.Tensor(multiview_c2ws).float()
735
-
736
- c2ws[:, 0:3, 1:3] *= -1
737
- c2ws = c2ws[:, [1, 0, 2, 3], :]
738
- c2ws[:, 2, :] *= -1
739
-
740
- w2cs = torch.inverse(c2ws)
741
- if multiview_intrinsics is None:
742
- multiview_intrinsics = torch.Tensor([[[multiview_focal_length, 0, 0.5], [0, multiview_focal_length, 0.5], [0, 0, 1]]]).repeat(c2ws.shape[0], 1, 1)
743
- K = multiview_intrinsics
744
- Rs = w2cs[:, :3, :3]
745
- Ts = w2cs[:, :3, 3]
746
- sizes = torch.Tensor([[1, 1]]).repeat(c2ws.shape[0], 1)
747
-
748
- assert height == width
749
- cond_rays = calculate_rays(K, sizes, Rs, Ts, height // 8)
750
- cond_rays = cond_rays.reshape(-1, height // 8, width // 8, 6)
751
- # padding = (0, 10)
752
- # cond_rays = torch.nn.functional.pad(cond_rays, padding, "constant", 0)
753
- cond_rays = torch.cat([cond_rays, cond_rays, cond_rays[..., :4]], dim=-1) * 1.658
754
- cond_rays = cond_rays[None].repeat(batch_size * num_images_per_prompt, 1, 1, 1, 1)
755
- cond_rays = cond_rays.permute(0, 1, 4, 2, 3)
756
- cond_rays = cond_rays.to(device, dtype=self.dtype)
757
-
758
- latents = einops.rearrange(latents, "b (f n) c h w -> b f n c h w", n=2)
759
- if cond_latents is not None:
760
- latents[:, cond_indices_images, 0] = cond_latents
761
- latents[:, cond_indices_rays, 1] = cond_rays
762
- latents = einops.rearrange(latents, "b f n c h w -> b (f n) c h w")
763
- else:
764
- if cond_latents is not None:
765
- latents[:, cond_indices] = cond_latents
766
-
767
- # denoising loop
768
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
769
- with self.progress_bar(total=num_inference_steps) as progress_bar:
770
- for i, t in enumerate(timesteps):
771
- # expand the latents if we are doing classifier free guidance
772
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
773
- input_t = torch.broadcast_to(einops.repeat(torch.Tensor([t]).to(device), "1 -> 1 f 1 1 1", f=latent_model_input.shape[1]), latent_model_input.shape).clone()
774
-
775
- if is_multiview:
776
- input_t = einops.rearrange(input_t, "b (f n) c h w -> b f n c h w", n=2)
777
- input_t[:, cond_indices_images, 0] = self.scheduler.timesteps[-1]
778
- input_t[:, cond_indices_rays, 1] = self.scheduler.timesteps[-1]
779
- input_t = einops.rearrange(input_t, "b f n c h w -> b (f n) c h w")
780
- else:
781
- input_t[:, cond_indices] = self.scheduler.timesteps[-1]
782
-
783
- # predict the noise residual
784
- noise_pred = self.transformer(
785
- samples=latent_model_input.to(self.dtype),
786
- timesteps=input_t,
787
- encoder_hidden_states=encoder_hidden_states.to(self.dtype),
788
- encoder_attention_mask=encoder_attention_mask,
789
- **forward_kwargs
790
- )
791
-
792
- # perform guidance
793
- if do_classifier_free_guidance:
794
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
795
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
796
-
797
- # compute the previous noisy sample x_t -> x_t-1
798
- bs, n_frame = noise_pred.shape[:2]
799
- noise_pred = einops.rearrange(noise_pred, "b f c h w -> (b f) c h w")
800
- latents = einops.rearrange(latents, "b f c h w -> (b f) c h w")
801
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
802
- latents = einops.rearrange(latents, "(b f) c h w -> b f c h w", b=bs, f=n_frame)
803
- if is_multiview:
804
- latents = einops.rearrange(latents, "b (f n) c h w -> b f n c h w", n=2)
805
- if cond_latents is not None:
806
- latents[:, cond_indices_images, 0] = cond_latents
807
- latents[:, cond_indices_rays, 1] = cond_rays
808
- latents = einops.rearrange(latents, "b f n c h w -> b (f n) c h w")
809
- else:
810
- if cond_latents is not None:
811
- latents[:, cond_indices] = cond_latents
812
-
813
- # call the callback, if provided
814
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
815
- progress_bar.update()
816
- if callback is not None and i % callback_steps == 0:
817
- callback(i, t, latents)
818
-
819
- decoded_latents = latents / 1.658
820
- # scale and decode the image latents with vae
821
- latents = 1 / self.vae.config.scaling_factor * latents
822
- if latents.ndim == 5:
823
- latents = latents[:, denoise_indices]
824
- latents = einops.rearrange(latents, "b f c h w -> (b f) c h w")
825
- image = self.vae.decode(latents.to(self.vae.dtype)).sample
826
-
827
- image = (image / 2 + 0.5).clamp(0, 1)
828
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
829
-
830
- if output_type == "pil":
831
- image = self.numpy_to_pil(image)
832
-
833
- if not return_dict:
834
- return (image, None)
835
-
836
- return OneDiffusionPipelineOutput(images=image, latents=decoded_latents)
837
-
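For illustration, a sketch of the multiview branch of `img2img`, reusing the `pipe` from the previous sketch. The slot layout is read off the `is_multiview` code above: each view occupies an image slot (even index) followed by a ray slot (odd index) in `denoise_mask`, with `0` marking a slot kept fixed as conditioning and `1` a slot to denoise. Here one reference photo conditions the first view while three novel views are generated; the prompt, file path, and mask values are assumptions for the example:

```py
import torch
from PIL import Image

reference = Image.open("reference_view.png")  # placeholder path

result = pipe.img2img(
    prompt="a wooden chair in an empty room",
    image=[reference],
    height=512,                      # the multiview path asserts height == width
    width=512,
    num_inference_steps=60,
    guidance_scale=4.0,
    # view 0: image + rays fixed; views 1-3: rays fixed, images denoised
    denoise_mask=[0, 0, 1, 0, 1, 0, 1, 0],
    is_multiview=True,
    multiview_azimuths=[0, 30, 60, 90],
    multiview_elevations=[0, 0, 0, 0],
    multiview_distances=1.7,
)
novel_views = result.images          # decoded frames for the denoised image slots
```

With the default `denoise_mask=[1, 0]` and `is_multiview=False`, the same method instead performs ordinary image-conditioned generation: the second slot holds the encoded reference and only the first slot is denoised.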
838
- def prepare_extra_step_kwargs(self, generator, eta):
839
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
840
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
841
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
842
- # and should be between [0, 1]
843
-
844
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
845
- extra_step_kwargs = {}
846
- if accepts_eta:
847
- extra_step_kwargs["eta"] = eta
848
-
849
- # check if the scheduler accepts generator
850
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
851
- if accepts_generator:
852
- extra_step_kwargs["generator"] = generator
853
- return extra_step_kwargs
854
-
855
- def check_inputs(self, prompt, height, width, callback_steps):
856
- if not isinstance(prompt, str) and not isinstance(prompt, list):
857
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
858
-
859
- if height % 16 != 0 or width % 16 != 0:
860
- raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
861
-
862
- if (callback_steps is None) or (
863
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
864
- ):
865
- raise ValueError(
866
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
867
- f" {type(callback_steps)}."
868
- )
869
-
870
- def get_timesteps(self, num_inference_steps, strength, device):
871
- # get the original timestep using init_timestep
872
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
873
-
874
- t_start = max(num_inference_steps - init_timestep, 0)
875
- timesteps = self.scheduler.timesteps[t_start:]
876
-
877
- return timesteps, num_inference_steps - t_start
878
-
879
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None, image=None):
880
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
881
- if isinstance(generator, list) and len(generator) != batch_size:
882
- raise ValueError(
883
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
884
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
885
- )
886
-
887
- if latents is None:
888
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
889
- else:
890
- latents = latents.to(device)
891
-
892
- if image is None:
893
- # scale the initial noise by the standard deviation required by the scheduler
894
- # latents = latents * self.scheduler.init_noise_sigma
895
- return latents
896
-
897
- image = image.to(device=device, dtype=dtype)
898
-
899
- if isinstance(generator, list) and len(generator) != batch_size:
900
- raise ValueError(
901
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
902
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
903
- )
904
- elif isinstance(generator, list):
905
- if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
906
- image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
907
- elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
908
- raise ValueError(
909
- f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
910
- )
911
- init_latents = [
912
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
913
- for i in range(batch_size)
914
- ]
915
- init_latents = torch.cat(init_latents, dim=0)
916
- else:
917
- init_latents = retrieve_latents(self.vae.encode(image.to(self.vae.dtype)), generator=generator)
918
-
919
- init_latents = self.vae.config.scaling_factor * init_latents
920
- init_latents = init_latents.to(device=device, dtype=dtype)
921
-
922
- init_latents = einops.rearrange(init_latents, "(bs views) c h w -> bs views c h w", bs=batch_size, views=init_latents.shape[0]//batch_size)
923
- # latents = einops.rearrange(latents, "b c h w -> b 1 c h w")
924
- # latents = torch.concat([latents, init_latents], dim=1)
925
- return init_latents
926
-
927
- def prepare_init_latents(self, batch_size, seq_length, num_channels_latents, height, width, dtype, device, generator, latents=None):
928
- shape = (batch_size, seq_length, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
929
- if isinstance(generator, list) and len(generator) != batch_size:
930
- raise ValueError(
931
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
932
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
933
- )
934
-
935
- if latents is None:
936
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
937
- else:
938
- latents = latents.to(device)
939
-
940
- return latents
941
-
942
- @torch.no_grad()
943
- def generate(
944
- self,
945
- prompt: Union[str, List[str]],
946
- num_inference_steps: int = 50,
947
- guidance_scale: float = 5.0,
948
- negative_prompt: Optional[Union[str, List[str]]] = None,
949
- num_images_per_prompt: Optional[int] = 1,
950
- height: Optional[int] = None,
951
- width: Optional[int] = None,
952
- eta: float = 0.0,
953
- generator: Optional[torch.Generator] = None,
954
- latents: Optional[torch.FloatTensor] = None,
955
- output_type: Optional[str] = "pil",
956
- return_dict: bool = True,
957
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
958
- callback_steps: Optional[int] = 1,
959
- ):
960
- """
961
- Function for image generation using the OneDiffusionPipeline.
962
- """
963
- return self(
964
- prompt=prompt,
965
- num_inference_steps=num_inference_steps,
966
- guidance_scale=guidance_scale,
967
- negative_prompt=negative_prompt,
968
- num_images_per_prompt=num_images_per_prompt,
969
- height=height,
970
- width=width,
971
- eta=eta,
972
- generator=generator,
973
- latents=latents,
974
- output_type=output_type,
975
- return_dict=return_dict,
976
- callback=callback,
977
- callback_steps=callback_steps,
978
- )
979
-
980
- @staticmethod
981
- def numpy_to_pil(images):
982
- """
983
- Convert a numpy image or a batch of images to a PIL image.
984
- """
985
- if images.ndim == 3:
986
- images = images[None, ...]
987
- images = (images * 255).round().astype("uint8")
988
- if images.shape[-1] == 1:
989
- # special case for grayscale (single channel) images
990
- pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
991
- else:
992
- pil_images = [Image.fromarray(image) for image in images]
993
-
994
- return pil_images
995
-
996
- @classmethod
997
- def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
998
- model_path = pretrained_model_name_or_path
999
- cache_dir = kwargs.pop("cache_dir", None)
1000
- force_download = kwargs.pop("force_download", False)
1001
- proxies = kwargs.pop("proxies", None)
1002
- local_files_only = kwargs.pop("local_files_only", None)
1003
- token = kwargs.pop("token", None)
1004
- revision = kwargs.pop("revision", None)
1005
- from_flax = kwargs.pop("from_flax", False)
1006
- torch_dtype = kwargs.pop("torch_dtype", None)
1007
- custom_pipeline = kwargs.pop("custom_pipeline", None)
1008
- custom_revision = kwargs.pop("custom_revision", None)
1009
- provider = kwargs.pop("provider", None)
1010
- sess_options = kwargs.pop("sess_options", None)
1011
- device_map = kwargs.pop("device_map", None)
1012
- max_memory = kwargs.pop("max_memory", None)
1013
- offload_folder = kwargs.pop("offload_folder", None)
1014
- offload_state_dict = kwargs.pop("offload_state_dict", False)
1015
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
1016
- variant = kwargs.pop("variant", None)
1017
- use_safetensors = kwargs.pop("use_safetensors", None)
1018
- use_onnx = kwargs.pop("use_onnx", None)
1019
- load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
1020
-
1021
- if low_cpu_mem_usage and not is_accelerate_available():
1022
- low_cpu_mem_usage = False
1023
- logger.warning(
1024
- "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
1025
- " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
1026
- " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
1027
- " install accelerate\n```\n."
1028
- )
1029
-
1030
- if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
1031
- raise NotImplementedError(
1032
- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
1033
- " `low_cpu_mem_usage=False`."
1034
- )
1035
-
1036
- if device_map is not None and not is_torch_version(">=", "1.9.0"):
1037
- raise NotImplementedError(
1038
- "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1039
- " `device_map=None`."
1040
- )
1041
-
1042
- if device_map is not None and not is_accelerate_available():
1043
- raise NotImplementedError(
1044
- "Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`."
1045
- )
1046
-
1047
- if device_map is not None and not isinstance(device_map, str):
1048
- raise ValueError("`device_map` must be a string.")
1049
-
1050
- if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP:
1051
- raise NotImplementedError(
1052
- f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}"
1053
- )
1054
-
1055
- if device_map is not None and device_map in SUPPORTED_DEVICE_MAP:
1056
- if is_accelerate_version("<", "0.28.0"):
1057
- raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.")
1058
-
1059
- if low_cpu_mem_usage is False and device_map is not None:
1060
- raise ValueError(
1061
- f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
1062
- " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
1063
- )
1064
-
1065
- transformer = NextDiT.from_pretrained(f"{model_path}", subfolder="transformer", torch_dtype=torch.float32, cache_dir=cache_dir)
1066
- vae = AutoencoderKL.from_pretrained(f"{model_path}", subfolder="vae", cache_dir=cache_dir)
1067
- text_encoder = T5EncoderModel.from_pretrained(f"{model_path}", subfolder="text_encoder", torch_dtype=torch.float16, cache_dir=cache_dir)
1068
- tokenizer = T5Tokenizer.from_pretrained(model_path, subfolder="tokenizer", cache_dir=cache_dir)
1069
- scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler", cache_dir=cache_dir)
1070
-
1071
- pipeline = cls(
1072
- transformer=transformer,
1073
- vae=vae,
1074
- text_encoder=text_encoder,
1075
- tokenizer=tokenizer,
1076
- scheduler=scheduler,
1077
- **kwargs
1078
- )
1079
-
1080
- return pipeline