WeichenFan commited on
Commit
b453b6a
·
1 Parent(s): 2f39f58

Add application file

Browse files
Files changed (4) hide show
  1. app.py +161 -0
  2. sd3_pipeline.py +1170 -0
  3. video_infer.py +30 -0
  4. wan_pipeline.py +616 -0
app.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from sd3_pipeline import StableDiffusion3Pipeline
3
+ import torch
4
+ import random
5
+ import numpy as np
6
+ import os
7
+ import gc
8
+ import tempfile
9
+ import imageio
10
+ from diffusers import AutoencoderKLWan
11
+ from wan_pipeline import WanPipeline
12
+ from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
13
+ from PIL import Image
14
+ from diffusers.utils import export_to_video
15
+
16
+ def set_seed(seed):
17
+ random.seed(seed)
18
+ os.environ['PYTHONHASHSEED'] = str(seed)
19
+ np.random.seed(seed)
20
+ torch.manual_seed(seed)
21
+ torch.cuda.manual_seed(seed)
22
+
23
+ # Model paths
24
+ model_paths = {
25
+ "sd3.5": "stabilityai/stable-diffusion-3.5-large",
26
+ "sd3": "stabilityai/stable-diffusion-3-medium-diffusers",
27
+ "wan-t2v": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
28
+ }
29
+
30
+ # Global variable for current model
31
+ current_model = None
32
+
33
+ # Folder to save video outputs
34
+ OUTPUT_DIR = "generated_videos"
35
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
36
+
37
+ def load_model(model_name):
38
+ global current_model
39
+ if current_model is not None:
40
+ del current_model # Delete the old model
41
+ torch.cuda.empty_cache() # Free GPU memory
42
+ gc.collect() # Force garbage collection
43
+
44
+ if "wan-t2v" in model_name:
45
+ vae = AutoencoderKLWan.from_pretrained(model_paths[model_name], subfolder="vae", torch_dtype=torch.float32)
46
+ scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=8.0)
47
+ current_model = WanPipeline.from_pretrained(model_paths[model_name], vae=vae, torch_dtype=torch.bfloat16).to("cuda")
48
+ current_model.scheduler = scheduler
49
+ else:
50
+ current_model = StableDiffusion3Pipeline.from_pretrained(model_paths[model_name], torch_dtype=torch.bfloat16).to("cuda")
51
+
52
+ return current_model
53
+
54
+ def generate_content(prompt, model_name, guidance_scale=7.5, num_inference_steps=50, use_cfg_zero_star=True, use_zero_init=True, zero_steps=0, seed=None, compare_mode=False):
55
+ model = load_model(model_name)
56
+ if seed is None:
57
+ seed = random.randint(0, 2**32 - 1)
58
+ set_seed(seed)
59
+
60
+ is_video_model = "wan-t2v" in model_name
61
+
62
+ if is_video_model:
63
+ if compare_mode:
64
+ set_seed(seed)
65
+ video1_frames = model(
66
+ prompt=prompt,
67
+ guidance_scale=guidance_scale,
68
+ num_frames=81,
69
+ use_cfg_zero_star=True,
70
+ use_zero_init=use_zero_init,
71
+ zero_steps=zero_steps
72
+ ).frames[0]
73
+ video1_path = os.path.join(OUTPUT_DIR, f"{seed}_CFG-Zero-Star.mp4")
74
+ export_to_video(video1_frames, video1_path, fps=16)
75
+
76
+ set_seed(seed)
77
+ video2_frames = model(
78
+ prompt=prompt,
79
+ guidance_scale=guidance_scale,
80
+ num_frames=81,
81
+ use_cfg_zero_star=False,
82
+ use_zero_init=use_zero_init,
83
+ zero_steps=zero_steps
84
+ ).frames[0]
85
+ video2_path = os.path.join(OUTPUT_DIR, f"{seed}_CFG.mp4")
86
+ export_to_video(video2_frames, video2_path, fps=16)
87
+
88
+ return None, None, video1_path, video2_path, seed
89
+ else:
90
+ video_frames = model(
91
+ prompt=prompt,
92
+ guidance_scale=guidance_scale,
93
+ num_frames=81,
94
+ use_cfg_zero_star=use_cfg_zero_star,
95
+ use_zero_init=use_zero_init,
96
+ zero_steps=zero_steps
97
+ ).frames[0]
98
+ video_path = save_video(video_frames, f"{seed}.mp4")
99
+ return None, None, video_path, None, seed
100
+
101
+ if compare_mode:
102
+ set_seed(seed)
103
+ image1 = model(
104
+ prompt,
105
+ guidance_scale=guidance_scale,
106
+ num_inference_steps=num_inference_steps,
107
+ use_cfg_zero_star=True,
108
+ use_zero_init=use_zero_init,
109
+ zero_steps=zero_steps
110
+ ).images[0]
111
+
112
+ set_seed(seed)
113
+ image2 = model(
114
+ prompt,
115
+ guidance_scale=guidance_scale,
116
+ num_inference_steps=num_inference_steps,
117
+ use_cfg_zero_star=False,
118
+ use_zero_init=use_zero_init,
119
+ zero_steps=zero_steps
120
+ ).images[0]
121
+
122
+ return image1, image2, None, None, seed
123
+ else:
124
+ image = model(
125
+ prompt,
126
+ guidance_scale=guidance_scale,
127
+ num_inference_steps=num_inference_steps,
128
+ use_cfg_zero_star=use_cfg_zero_star,
129
+ use_zero_init=use_zero_init,
130
+ zero_steps=zero_steps
131
+ ).images[0]
132
+ if use_cfg_zero_star:
133
+ return image, None, None, None, seed
134
+ else:
135
+ return None, image, None, None, seed
136
+
137
+ # Gradio UI
138
+ demo = gr.Interface(
139
+ fn=generate_content,
140
+ inputs=[
141
+ gr.Textbox(value="A capybara holding a sign that reads Hello World", label="Enter your prompt"),
142
+ gr.Dropdown(choices=list(model_paths.keys()), label="Choose Model"),
143
+ gr.Slider(1, 20, value=4.0, step=0.5, label="Guidance Scale"),
144
+ gr.Slider(10, 100, value=28, step=5, label="Inference Steps"),
145
+ gr.Checkbox(value=True, label="Use CFG Zero Star"),
146
+ gr.Checkbox(value=True, label="Use Zero Init"),
147
+ gr.Slider(0, 20, value=0, step=1, label="Zero out steps"),
148
+ gr.Number(value=42, label="Seed (Leave blank for random)"),
149
+ gr.Checkbox(value=True, label="Compare Mode")
150
+ ],
151
+ outputs=[
152
+ gr.Image(type="pil", label="CFG-Zero* Image"),
153
+ gr.Image(type="pil", label="CFG Image"),
154
+ gr.Video(label="CFG-Zero* Video"),
155
+ gr.Video(label="CFG Video"),
156
+ gr.Textbox(label="Used Seed")
157
+ ],
158
+ title="CFG-Zero*: Improved Classifier-Free Guidance for Flow Matching Models",
159
+ )
160
+
161
+ demo.launch(server_name="127.0.0.1", server_port=7860)
sd3_pipeline.py ADDED
@@ -0,0 +1,1170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ from transformers import (
20
+ BaseImageProcessor,
21
+ CLIPTextModelWithProjection,
22
+ CLIPTokenizer,
23
+ PreTrainedModel,
24
+ T5EncoderModel,
25
+ T5TokenizerFast,
26
+ )
27
+
28
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
29
+ from diffusers.loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
30
+ from diffusers.models.autoencoders import AutoencoderKL
31
+ from diffusers.models.transformers import SD3Transformer2DModel
32
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
33
+ from diffusers.utils import (
34
+ USE_PEFT_BACKEND,
35
+ is_torch_xla_available,
36
+ logging,
37
+ replace_example_docstring,
38
+ scale_lora_layers,
39
+ unscale_lora_layers,
40
+ )
41
+ from diffusers.utils.torch_utils import randn_tensor
42
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
43
+ from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
44
+
45
+
46
+ if is_torch_xla_available():
47
+ import torch_xla.core.xla_model as xm
48
+
49
+ XLA_AVAILABLE = True
50
+ else:
51
+ XLA_AVAILABLE = False
52
+
53
+
54
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
55
+
56
+ EXAMPLE_DOC_STRING = """
57
+ Examples:
58
+ ```py
59
+ >>> import torch
60
+ >>> from diffusers import StableDiffusion3Pipeline
61
+
62
+ >>> pipe = StableDiffusion3Pipeline.from_pretrained(
63
+ ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
64
+ ... )
65
+ >>> pipe.to("cuda")
66
+ >>> prompt = "A cat holding a sign that says hello world"
67
+ >>> image = pipe(prompt).images[0]
68
+ >>> image.save("sd3.png")
69
+ ```
70
+ """
71
+
72
+ def optimized_scale(positive_flat, negative_flat):
73
+
74
+ # Calculate dot production
75
+ dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
76
+
77
+ # Squared norm of uncondition
78
+ squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
79
+
80
+ # st_star = v_cond^T * v_uncond / ||v_uncond||^2
81
+ st_star = dot_product / squared_norm
82
+
83
+ return st_star
84
+
85
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
86
+ def calculate_shift(
87
+ image_seq_len,
88
+ base_seq_len: int = 256,
89
+ max_seq_len: int = 4096,
90
+ base_shift: float = 0.5,
91
+ max_shift: float = 1.16,
92
+ ):
93
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
94
+ b = base_shift - m * base_seq_len
95
+ mu = image_seq_len * m + b
96
+ return mu
97
+
98
+
99
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
100
+ def retrieve_timesteps(
101
+ scheduler,
102
+ num_inference_steps: Optional[int] = None,
103
+ device: Optional[Union[str, torch.device]] = None,
104
+ timesteps: Optional[List[int]] = None,
105
+ sigmas: Optional[List[float]] = None,
106
+ **kwargs,
107
+ ):
108
+ r"""
109
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
110
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
111
+
112
+ Args:
113
+ scheduler (`SchedulerMixin`):
114
+ The scheduler to get timesteps from.
115
+ num_inference_steps (`int`):
116
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
117
+ must be `None`.
118
+ device (`str` or `torch.device`, *optional*):
119
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
120
+ timesteps (`List[int]`, *optional*):
121
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
122
+ `num_inference_steps` and `sigmas` must be `None`.
123
+ sigmas (`List[float]`, *optional*):
124
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
125
+ `num_inference_steps` and `timesteps` must be `None`.
126
+
127
+ Returns:
128
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
129
+ second element is the number of inference steps.
130
+ """
131
+ if timesteps is not None and sigmas is not None:
132
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
133
+ if timesteps is not None:
134
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
135
+ if not accepts_timesteps:
136
+ raise ValueError(
137
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
138
+ f" timestep schedules. Please check whether you are using the correct scheduler."
139
+ )
140
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
141
+ timesteps = scheduler.timesteps
142
+ num_inference_steps = len(timesteps)
143
+ elif sigmas is not None:
144
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
145
+ if not accept_sigmas:
146
+ raise ValueError(
147
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
148
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
149
+ )
150
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
151
+ timesteps = scheduler.timesteps
152
+ num_inference_steps = len(timesteps)
153
+ else:
154
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
155
+ timesteps = scheduler.timesteps
156
+ return timesteps, num_inference_steps
157
+
158
+
159
+ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
160
+ r"""
161
+ Args:
162
+ transformer ([`SD3Transformer2DModel`]):
163
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
164
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
165
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
166
+ vae ([`AutoencoderKL`]):
167
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
168
+ text_encoder ([`CLIPTextModelWithProjection`]):
169
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
170
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
171
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
172
+ as its dimension.
173
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
174
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
175
+ specifically the
176
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
177
+ variant.
178
+ text_encoder_3 ([`T5EncoderModel`]):
179
+ Frozen text-encoder. Stable Diffusion 3 uses
180
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
181
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
182
+ tokenizer (`CLIPTokenizer`):
183
+ Tokenizer of class
184
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
185
+ tokenizer_2 (`CLIPTokenizer`):
186
+ Second Tokenizer of class
187
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
188
+ tokenizer_3 (`T5TokenizerFast`):
189
+ Tokenizer of class
190
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
191
+ image_encoder (`PreTrainedModel`, *optional*):
192
+ Pre-trained Vision Model for IP Adapter.
193
+ feature_extractor (`BaseImageProcessor`, *optional*):
194
+ Image processor for IP Adapter.
195
+ """
196
+
197
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
198
+ _optional_components = ["image_encoder", "feature_extractor"]
199
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
200
+
201
+ def __init__(
202
+ self,
203
+ transformer: SD3Transformer2DModel,
204
+ scheduler: FlowMatchEulerDiscreteScheduler,
205
+ vae: AutoencoderKL,
206
+ text_encoder: CLIPTextModelWithProjection,
207
+ tokenizer: CLIPTokenizer,
208
+ text_encoder_2: CLIPTextModelWithProjection,
209
+ tokenizer_2: CLIPTokenizer,
210
+ text_encoder_3: T5EncoderModel,
211
+ tokenizer_3: T5TokenizerFast,
212
+ image_encoder: PreTrainedModel = None,
213
+ feature_extractor: BaseImageProcessor = None,
214
+ ):
215
+ super().__init__()
216
+
217
+ self.register_modules(
218
+ vae=vae,
219
+ text_encoder=text_encoder,
220
+ text_encoder_2=text_encoder_2,
221
+ text_encoder_3=text_encoder_3,
222
+ tokenizer=tokenizer,
223
+ tokenizer_2=tokenizer_2,
224
+ tokenizer_3=tokenizer_3,
225
+ transformer=transformer,
226
+ scheduler=scheduler,
227
+ image_encoder=image_encoder,
228
+ feature_extractor=feature_extractor,
229
+ )
230
+ self.vae_scale_factor = (
231
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
232
+ )
233
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
234
+ self.tokenizer_max_length = (
235
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
236
+ )
237
+ self.default_sample_size = (
238
+ self.transformer.config.sample_size
239
+ if hasattr(self, "transformer") and self.transformer is not None
240
+ else 128
241
+ )
242
+ self.patch_size = (
243
+ self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
244
+ )
245
+
246
+ def _get_t5_prompt_embeds(
247
+ self,
248
+ prompt: Union[str, List[str]] = None,
249
+ num_images_per_prompt: int = 1,
250
+ max_sequence_length: int = 256,
251
+ device: Optional[torch.device] = None,
252
+ dtype: Optional[torch.dtype] = None,
253
+ ):
254
+ device = device or self._execution_device
255
+ dtype = dtype or self.text_encoder.dtype
256
+
257
+ prompt = [prompt] if isinstance(prompt, str) else prompt
258
+ batch_size = len(prompt)
259
+
260
+ if self.text_encoder_3 is None:
261
+ return torch.zeros(
262
+ (
263
+ batch_size * num_images_per_prompt,
264
+ self.tokenizer_max_length,
265
+ self.transformer.config.joint_attention_dim,
266
+ ),
267
+ device=device,
268
+ dtype=dtype,
269
+ )
270
+
271
+ text_inputs = self.tokenizer_3(
272
+ prompt,
273
+ padding="max_length",
274
+ max_length=max_sequence_length,
275
+ truncation=True,
276
+ add_special_tokens=True,
277
+ return_tensors="pt",
278
+ )
279
+ text_input_ids = text_inputs.input_ids
280
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
281
+
282
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
283
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
284
+ logger.warning(
285
+ "The following part of your input was truncated because `max_sequence_length` is set to "
286
+ f" {max_sequence_length} tokens: {removed_text}"
287
+ )
288
+
289
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
290
+
291
+ dtype = self.text_encoder_3.dtype
292
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
293
+
294
+ _, seq_len, _ = prompt_embeds.shape
295
+
296
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
297
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
298
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
299
+
300
+ return prompt_embeds
301
+
302
+ def _get_clip_prompt_embeds(
303
+ self,
304
+ prompt: Union[str, List[str]],
305
+ num_images_per_prompt: int = 1,
306
+ device: Optional[torch.device] = None,
307
+ clip_skip: Optional[int] = None,
308
+ clip_model_index: int = 0,
309
+ ):
310
+ device = device or self._execution_device
311
+
312
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
313
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
314
+
315
+ tokenizer = clip_tokenizers[clip_model_index]
316
+ text_encoder = clip_text_encoders[clip_model_index]
317
+
318
+ prompt = [prompt] if isinstance(prompt, str) else prompt
319
+ batch_size = len(prompt)
320
+
321
+ text_inputs = tokenizer(
322
+ prompt,
323
+ padding="max_length",
324
+ max_length=self.tokenizer_max_length,
325
+ truncation=True,
326
+ return_tensors="pt",
327
+ )
328
+
329
+ text_input_ids = text_inputs.input_ids
330
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
331
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
332
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
333
+ logger.warning(
334
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
335
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
336
+ )
337
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
338
+ pooled_prompt_embeds = prompt_embeds[0]
339
+
340
+ if clip_skip is None:
341
+ prompt_embeds = prompt_embeds.hidden_states[-2]
342
+ else:
343
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
344
+
345
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
346
+
347
+ _, seq_len, _ = prompt_embeds.shape
348
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
349
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
350
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
351
+
352
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
353
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
354
+
355
+ return prompt_embeds, pooled_prompt_embeds
356
+
357
+ def encode_prompt(
358
+ self,
359
+ prompt: Union[str, List[str]],
360
+ prompt_2: Union[str, List[str]],
361
+ prompt_3: Union[str, List[str]],
362
+ device: Optional[torch.device] = None,
363
+ num_images_per_prompt: int = 1,
364
+ do_classifier_free_guidance: bool = True,
365
+ negative_prompt: Optional[Union[str, List[str]]] = None,
366
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
367
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
368
+ prompt_embeds: Optional[torch.FloatTensor] = None,
369
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
370
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
371
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
372
+ clip_skip: Optional[int] = None,
373
+ max_sequence_length: int = 256,
374
+ lora_scale: Optional[float] = None,
375
+ ):
376
+ r"""
377
+
378
+ Args:
379
+ prompt (`str` or `List[str]`, *optional*):
380
+ prompt to be encoded
381
+ prompt_2 (`str` or `List[str]`, *optional*):
382
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
383
+ used in all text-encoders
384
+ prompt_3 (`str` or `List[str]`, *optional*):
385
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
386
+ used in all text-encoders
387
+ device: (`torch.device`):
388
+ torch device
389
+ num_images_per_prompt (`int`):
390
+ number of images that should be generated per prompt
391
+ do_classifier_free_guidance (`bool`):
392
+ whether to use classifier free guidance or not
393
+ negative_prompt (`str` or `List[str]`, *optional*):
394
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
395
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
396
+ less than `1`).
397
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
398
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
399
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
400
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
401
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
402
+ `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
403
+ prompt_embeds (`torch.FloatTensor`, *optional*):
404
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
405
+ provided, text embeddings will be generated from `prompt` input argument.
406
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
407
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
408
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
409
+ argument.
410
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
411
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
412
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
413
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
414
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
415
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
416
+ input argument.
417
+ clip_skip (`int`, *optional*):
418
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
419
+ the output of the pre-final layer will be used for computing the prompt embeddings.
420
+ lora_scale (`float`, *optional*):
421
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
422
+ """
423
+ device = device or self._execution_device
424
+
425
+ # set lora scale so that monkey patched LoRA
426
+ # function of text encoder can correctly access it
427
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
428
+ self._lora_scale = lora_scale
429
+
430
+ # dynamically adjust the LoRA scale
431
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
432
+ scale_lora_layers(self.text_encoder, lora_scale)
433
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
434
+ scale_lora_layers(self.text_encoder_2, lora_scale)
435
+
436
+ prompt = [prompt] if isinstance(prompt, str) else prompt
437
+ if prompt is not None:
438
+ batch_size = len(prompt)
439
+ else:
440
+ batch_size = prompt_embeds.shape[0]
441
+
442
+ if prompt_embeds is None:
443
+ prompt_2 = prompt_2 or prompt
444
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
445
+
446
+ prompt_3 = prompt_3 or prompt
447
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
448
+
449
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
450
+ prompt=prompt,
451
+ device=device,
452
+ num_images_per_prompt=num_images_per_prompt,
453
+ clip_skip=clip_skip,
454
+ clip_model_index=0,
455
+ )
456
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
457
+ prompt=prompt_2,
458
+ device=device,
459
+ num_images_per_prompt=num_images_per_prompt,
460
+ clip_skip=clip_skip,
461
+ clip_model_index=1,
462
+ )
463
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
464
+
465
+ t5_prompt_embed = self._get_t5_prompt_embeds(
466
+ prompt=prompt_3,
467
+ num_images_per_prompt=num_images_per_prompt,
468
+ max_sequence_length=max_sequence_length,
469
+ device=device,
470
+ )
471
+ clip_prompt_embeds = torch.nn.functional.pad(
472
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
473
+ )
474
+
475
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
476
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
477
+
478
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
479
+ negative_prompt = negative_prompt or ""
480
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
481
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
482
+
483
+ # normalize str to list
484
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
485
+ negative_prompt_2 = (
486
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
487
+ )
488
+ negative_prompt_3 = (
489
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
490
+ )
491
+
492
+ if prompt is not None and type(prompt) is not type(negative_prompt):
493
+ raise TypeError(
494
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
495
+ f" {type(prompt)}."
496
+ )
497
+ elif batch_size != len(negative_prompt):
498
+ raise ValueError(
499
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
500
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
501
+ " the batch size of `prompt`."
502
+ )
503
+
504
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
505
+ negative_prompt,
506
+ device=device,
507
+ num_images_per_prompt=num_images_per_prompt,
508
+ clip_skip=None,
509
+ clip_model_index=0,
510
+ )
511
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
512
+ negative_prompt_2,
513
+ device=device,
514
+ num_images_per_prompt=num_images_per_prompt,
515
+ clip_skip=None,
516
+ clip_model_index=1,
517
+ )
518
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
519
+
520
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
521
+ prompt=negative_prompt_3,
522
+ num_images_per_prompt=num_images_per_prompt,
523
+ max_sequence_length=max_sequence_length,
524
+ device=device,
525
+ )
526
+
527
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
528
+ negative_clip_prompt_embeds,
529
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
530
+ )
531
+
532
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
533
+ negative_pooled_prompt_embeds = torch.cat(
534
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
535
+ )
536
+
537
+ if self.text_encoder is not None:
538
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
539
+ # Retrieve the original scale by scaling back the LoRA layers
540
+ unscale_lora_layers(self.text_encoder, lora_scale)
541
+
542
+ if self.text_encoder_2 is not None:
543
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
544
+ # Retrieve the original scale by scaling back the LoRA layers
545
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
546
+
547
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
548
+
549
+ def check_inputs(
550
+ self,
551
+ prompt,
552
+ prompt_2,
553
+ prompt_3,
554
+ height,
555
+ width,
556
+ negative_prompt=None,
557
+ negative_prompt_2=None,
558
+ negative_prompt_3=None,
559
+ prompt_embeds=None,
560
+ negative_prompt_embeds=None,
561
+ pooled_prompt_embeds=None,
562
+ negative_pooled_prompt_embeds=None,
563
+ callback_on_step_end_tensor_inputs=None,
564
+ max_sequence_length=None,
565
+ ):
566
+ if (
567
+ height % (self.vae_scale_factor * self.patch_size) != 0
568
+ or width % (self.vae_scale_factor * self.patch_size) != 0
569
+ ):
570
+ raise ValueError(
571
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
572
+ f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
573
+ )
574
+
575
+ if callback_on_step_end_tensor_inputs is not None and not all(
576
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
577
+ ):
578
+ raise ValueError(
579
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
580
+ )
581
+
582
+ if prompt is not None and prompt_embeds is not None:
583
+ raise ValueError(
584
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
585
+ " only forward one of the two."
586
+ )
587
+ elif prompt_2 is not None and prompt_embeds is not None:
588
+ raise ValueError(
589
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
590
+ " only forward one of the two."
591
+ )
592
+ elif prompt_3 is not None and prompt_embeds is not None:
593
+ raise ValueError(
594
+ f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
595
+ " only forward one of the two."
596
+ )
597
+ elif prompt is None and prompt_embeds is None:
598
+ raise ValueError(
599
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
600
+ )
601
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
602
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
603
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
604
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
605
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
606
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
607
+
608
+ if negative_prompt is not None and negative_prompt_embeds is not None:
609
+ raise ValueError(
610
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
611
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
612
+ )
613
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
614
+ raise ValueError(
615
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
616
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
617
+ )
618
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
619
+ raise ValueError(
620
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
621
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
622
+ )
623
+
624
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
625
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
626
+ raise ValueError(
627
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
628
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
629
+ f" {negative_prompt_embeds.shape}."
630
+ )
631
+
632
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
633
+ raise ValueError(
634
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
635
+ )
636
+
637
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
638
+ raise ValueError(
639
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
640
+ )
641
+
642
+ if max_sequence_length is not None and max_sequence_length > 512:
643
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
644
+
645
+ def prepare_latents(
646
+ self,
647
+ batch_size,
648
+ num_channels_latents,
649
+ height,
650
+ width,
651
+ dtype,
652
+ device,
653
+ generator,
654
+ latents=None,
655
+ ):
656
+ if latents is not None:
657
+ return latents.to(device=device, dtype=dtype)
658
+
659
+ shape = (
660
+ batch_size,
661
+ num_channels_latents,
662
+ int(height) // self.vae_scale_factor,
663
+ int(width) // self.vae_scale_factor,
664
+ )
665
+
666
+ if isinstance(generator, list) and len(generator) != batch_size:
667
+ raise ValueError(
668
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
669
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
670
+ )
671
+
672
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
673
+
674
+ return latents
675
+
676
+ @property
677
+ def guidance_scale(self):
678
+ return self._guidance_scale
679
+
680
+ @property
681
+ def skip_guidance_layers(self):
682
+ return self._skip_guidance_layers
683
+
684
+ @property
685
+ def clip_skip(self):
686
+ return self._clip_skip
687
+
688
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
689
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
690
+ # corresponds to doing no classifier free guidance.
691
+ @property
692
+ def do_classifier_free_guidance(self):
693
+ return self._guidance_scale > 1
694
+
695
+ @property
696
+ def joint_attention_kwargs(self):
697
+ return self._joint_attention_kwargs
698
+
699
+ @property
700
+ def num_timesteps(self):
701
+ return self._num_timesteps
702
+
703
+ @property
704
+ def interrupt(self):
705
+ return self._interrupt
706
+
707
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
708
+ def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
709
+ """Encodes the given image into a feature representation using a pre-trained image encoder.
710
+
711
+ Args:
712
+ image (`PipelineImageInput`):
713
+ Input image to be encoded.
714
+ device: (`torch.device`):
715
+ Torch device.
716
+
717
+ Returns:
718
+ `torch.Tensor`: The encoded image feature representation.
719
+ """
720
+ if not isinstance(image, torch.Tensor):
721
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
722
+
723
+ image = image.to(device=device, dtype=self.dtype)
724
+
725
+ return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
726
+
727
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
728
+ def prepare_ip_adapter_image_embeds(
729
+ self,
730
+ ip_adapter_image: Optional[PipelineImageInput] = None,
731
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
732
+ device: Optional[torch.device] = None,
733
+ num_images_per_prompt: int = 1,
734
+ do_classifier_free_guidance: bool = True,
735
+ ) -> torch.Tensor:
736
+ """Prepares image embeddings for use in the IP-Adapter.
737
+
738
+ Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
739
+
740
+ Args:
741
+ ip_adapter_image (`PipelineImageInput`, *optional*):
742
+ The input image to extract features from for IP-Adapter.
743
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
744
+ Precomputed image embeddings.
745
+ device: (`torch.device`, *optional*):
746
+ Torch device.
747
+ num_images_per_prompt (`int`, defaults to 1):
748
+ Number of images that should be generated per prompt.
749
+ do_classifier_free_guidance (`bool`, defaults to True):
750
+ Whether to use classifier free guidance or not.
751
+ """
752
+ device = device or self._execution_device
753
+
754
+ if ip_adapter_image_embeds is not None:
755
+ if do_classifier_free_guidance:
756
+ single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
757
+ else:
758
+ single_image_embeds = ip_adapter_image_embeds
759
+ elif ip_adapter_image is not None:
760
+ single_image_embeds = self.encode_image(ip_adapter_image, device)
761
+ if do_classifier_free_guidance:
762
+ single_negative_image_embeds = torch.zeros_like(single_image_embeds)
763
+ else:
764
+ raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
765
+
766
+ image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
767
+
768
+ if do_classifier_free_guidance:
769
+ negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
770
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
771
+
772
+ return image_embeds.to(device=device)
773
+
774
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
775
+ if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
776
+ logger.warning(
777
+ "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
778
+ "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
779
+ "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
780
+ )
781
+
782
+ super().enable_sequential_cpu_offload(*args, **kwargs)
783
+
784
+ @torch.no_grad()
785
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
786
+ def __call__(
787
+ self,
788
+ prompt: Union[str, List[str]] = None,
789
+ prompt_2: Optional[Union[str, List[str]]] = None,
790
+ prompt_3: Optional[Union[str, List[str]]] = None,
791
+ height: Optional[int] = None,
792
+ width: Optional[int] = None,
793
+ num_inference_steps: int = 28,
794
+ sigmas: Optional[List[float]] = None,
795
+ guidance_scale: float = 7.0,
796
+ negative_prompt: Optional[Union[str, List[str]]] = None,
797
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
798
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
799
+ num_images_per_prompt: Optional[int] = 1,
800
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
801
+ latents: Optional[torch.FloatTensor] = None,
802
+ prompt_embeds: Optional[torch.FloatTensor] = None,
803
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
804
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
805
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
806
+ ip_adapter_image: Optional[PipelineImageInput] = None,
807
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
808
+ output_type: Optional[str] = "pil",
809
+ return_dict: bool = True,
810
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
811
+ clip_skip: Optional[int] = None,
812
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
813
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
814
+ max_sequence_length: int = 256,
815
+ skip_guidance_layers: List[int] = None,
816
+ skip_layer_guidance_scale: float = 2.8,
817
+ skip_layer_guidance_stop: float = 0.2,
818
+ skip_layer_guidance_start: float = 0.01,
819
+ mu: Optional[float] = None,
820
+ use_cfg_zero_star: Optional[bool] = False,
821
+ use_zero_init: Optional[bool] = True,
822
+ zero_steps: Optional[int] = 0,
823
+ ):
824
+ r"""
825
+ Function invoked when calling the pipeline for generation.
826
+
827
+ Args:
828
+ prompt (`str` or `List[str]`, *optional*):
829
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
830
+ instead.
831
+ prompt_2 (`str` or `List[str]`, *optional*):
832
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
833
+ will be used instead
834
+ prompt_3 (`str` or `List[str]`, *optional*):
835
+ The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
836
+ will be used instead
837
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
838
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
839
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
840
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
841
+ num_inference_steps (`int`, *optional*, defaults to 50):
842
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
843
+ expense of slower inference.
844
+ sigmas (`List[float]`, *optional*):
845
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
846
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
847
+ will be used.
848
+ guidance_scale (`float`, *optional*, defaults to 7.0):
849
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
850
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
851
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
852
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
853
+ usually at the expense of lower image quality.
854
+ negative_prompt (`str` or `List[str]`, *optional*):
855
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
856
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
857
+ less than `1`).
858
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
859
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
860
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
861
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
862
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
863
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
864
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
865
+ The number of images to generate per prompt.
866
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
867
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
868
+ to make generation deterministic.
869
+ latents (`torch.FloatTensor`, *optional*):
870
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
871
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
872
+ tensor will ge generated by sampling using the supplied random `generator`.
873
+ prompt_embeds (`torch.FloatTensor`, *optional*):
874
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
875
+ provided, text embeddings will be generated from `prompt` input argument.
876
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
877
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
878
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
879
+ argument.
880
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
881
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
882
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
883
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
884
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
885
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
886
+ input argument.
887
+ ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
888
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
889
+ Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
890
+ emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
891
+ `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
892
+ output_type (`str`, *optional*, defaults to `"pil"`):
893
+ The output format of the generate image. Choose between
894
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
895
+ return_dict (`bool`, *optional*, defaults to `True`):
896
+ Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
897
+ a plain tuple.
898
+ joint_attention_kwargs (`dict`, *optional*):
899
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
900
+ `self.processor` in
901
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
902
+ callback_on_step_end (`Callable`, *optional*):
903
+ A function that calls at the end of each denoising steps during the inference. The function is called
904
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
905
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
906
+ `callback_on_step_end_tensor_inputs`.
907
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
908
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
909
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
910
+ `._callback_tensor_inputs` attribute of your pipeline class.
911
+ max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
912
+ skip_guidance_layers (`List[int]`, *optional*):
913
+ A list of integers that specify layers to skip during guidance. If not provided, all layers will be
914
+ used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
915
+ Recommended value by StabiltyAI for Stable Diffusion 3.5 Medium is [7, 8, 9].
916
+ skip_layer_guidance_scale (`int`, *optional*): The scale of the guidance for the layers specified in
917
+ `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
918
+ with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
919
+ with a scale of `1`.
920
+ skip_layer_guidance_stop (`int`, *optional*): The step at which the guidance for the layers specified in
921
+ `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
922
+ `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
923
+ StabiltyAI for Stable Diffusion 3.5 Medium is 0.2.
924
+ skip_layer_guidance_start (`int`, *optional*): The step at which the guidance for the layers specified in
925
+ `skip_guidance_layers` will start. The guidance will be applied to the layers specified in
926
+ `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
927
+ StabiltyAI for Stable Diffusion 3.5 Medium is 0.01.
928
+ mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
929
+
930
+ Examples:
931
+
932
+ Returns:
933
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
934
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
935
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
936
+ """
937
+
938
+ height = height or self.default_sample_size * self.vae_scale_factor
939
+ width = width or self.default_sample_size * self.vae_scale_factor
940
+
941
+ # 1. Check inputs. Raise error if not correct
942
+ self.check_inputs(
943
+ prompt,
944
+ prompt_2,
945
+ prompt_3,
946
+ height,
947
+ width,
948
+ negative_prompt=negative_prompt,
949
+ negative_prompt_2=negative_prompt_2,
950
+ negative_prompt_3=negative_prompt_3,
951
+ prompt_embeds=prompt_embeds,
952
+ negative_prompt_embeds=negative_prompt_embeds,
953
+ pooled_prompt_embeds=pooled_prompt_embeds,
954
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
955
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
956
+ max_sequence_length=max_sequence_length,
957
+ )
958
+
959
+ self._guidance_scale = guidance_scale
960
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
961
+ self._clip_skip = clip_skip
962
+ self._joint_attention_kwargs = joint_attention_kwargs
963
+ self._interrupt = False
964
+
965
+ # 2. Define call parameters
966
+ if prompt is not None and isinstance(prompt, str):
967
+ batch_size = 1
968
+ elif prompt is not None and isinstance(prompt, list):
969
+ batch_size = len(prompt)
970
+ else:
971
+ batch_size = prompt_embeds.shape[0]
972
+
973
+ device = self._execution_device
974
+
975
+ lora_scale = (
976
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
977
+ )
978
+ (
979
+ prompt_embeds,
980
+ negative_prompt_embeds,
981
+ pooled_prompt_embeds,
982
+ negative_pooled_prompt_embeds,
983
+ ) = self.encode_prompt(
984
+ prompt=prompt,
985
+ prompt_2=prompt_2,
986
+ prompt_3=prompt_3,
987
+ negative_prompt=negative_prompt,
988
+ negative_prompt_2=negative_prompt_2,
989
+ negative_prompt_3=negative_prompt_3,
990
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
991
+ prompt_embeds=prompt_embeds,
992
+ negative_prompt_embeds=negative_prompt_embeds,
993
+ pooled_prompt_embeds=pooled_prompt_embeds,
994
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
995
+ device=device,
996
+ clip_skip=self.clip_skip,
997
+ num_images_per_prompt=num_images_per_prompt,
998
+ max_sequence_length=max_sequence_length,
999
+ lora_scale=lora_scale,
1000
+ )
1001
+ if self.do_classifier_free_guidance:
1002
+ if skip_guidance_layers is not None:
1003
+ original_prompt_embeds = prompt_embeds
1004
+ original_pooled_prompt_embeds = pooled_prompt_embeds
1005
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1006
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
1007
+
1008
+ # 4. Prepare latent variables
1009
+ num_channels_latents = self.transformer.config.in_channels
1010
+ latents = self.prepare_latents(
1011
+ batch_size * num_images_per_prompt,
1012
+ num_channels_latents,
1013
+ height,
1014
+ width,
1015
+ prompt_embeds.dtype,
1016
+ device,
1017
+ generator,
1018
+ latents,
1019
+ )
1020
+
1021
+ # 5. Prepare timesteps
1022
+ scheduler_kwargs = {}
1023
+ if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
1024
+ _, _, height, width = latents.shape
1025
+ image_seq_len = (height // self.transformer.config.patch_size) * (
1026
+ width // self.transformer.config.patch_size
1027
+ )
1028
+ mu = calculate_shift(
1029
+ image_seq_len,
1030
+ self.scheduler.config.base_image_seq_len,
1031
+ self.scheduler.config.max_image_seq_len,
1032
+ self.scheduler.config.base_shift,
1033
+ self.scheduler.config.max_shift,
1034
+ )
1035
+ scheduler_kwargs["mu"] = mu
1036
+ elif mu is not None:
1037
+ scheduler_kwargs["mu"] = mu
1038
+ timesteps, num_inference_steps = retrieve_timesteps(
1039
+ self.scheduler,
1040
+ num_inference_steps,
1041
+ device,
1042
+ sigmas=sigmas,
1043
+ **scheduler_kwargs,
1044
+ )
1045
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1046
+ self._num_timesteps = len(timesteps)
1047
+
1048
+ # 6. Prepare image embeddings
1049
+ if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
1050
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
1051
+ ip_adapter_image,
1052
+ ip_adapter_image_embeds,
1053
+ device,
1054
+ batch_size * num_images_per_prompt,
1055
+ self.do_classifier_free_guidance,
1056
+ )
1057
+
1058
+ if self.joint_attention_kwargs is None:
1059
+ self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
1060
+ else:
1061
+ self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
1062
+
1063
+ sigmas = timesteps / self.scheduler.config.num_train_timesteps
1064
+
1065
+ # 7. Denoising loop
1066
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1067
+ for i, t in enumerate(timesteps):
1068
+ if self.interrupt:
1069
+ continue
1070
+
1071
+ # expand the latents if we are doing classifier free guidance
1072
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1073
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1074
+ timestep = t.expand(latent_model_input.shape[0])
1075
+
1076
+ noise_pred = self.transformer(
1077
+ hidden_states=latent_model_input,
1078
+ timestep=timestep,
1079
+ encoder_hidden_states=prompt_embeds,
1080
+ pooled_projections=pooled_prompt_embeds,
1081
+ joint_attention_kwargs=self.joint_attention_kwargs,
1082
+ return_dict=False,
1083
+ )[0]
1084
+
1085
+ # perform guidance
1086
+ if self.do_classifier_free_guidance:
1087
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1088
+
1089
+ if use_cfg_zero_star:
1090
+ positive_flat = noise_pred_text.view(batch_size, -1)
1091
+ negative_flat = noise_pred_uncond.view(batch_size, -1)
1092
+
1093
+ alpha = optimized_scale(positive_flat,negative_flat)
1094
+ alpha = alpha.view(batch_size, 1, 1, 1)
1095
+
1096
+ if (i <= zero_steps) and use_zero_init:
1097
+ noise_pred = noise_pred_text*0.
1098
+ else:
1099
+ noise_pred = noise_pred_uncond * alpha + guidance_scale * (noise_pred_text - noise_pred_uncond * alpha)
1100
+ else:
1101
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1102
+
1103
+ should_skip_layers = (
1104
+ True
1105
+ if i > num_inference_steps * skip_layer_guidance_start
1106
+ and i < num_inference_steps * skip_layer_guidance_stop
1107
+ else False
1108
+ )
1109
+ if skip_guidance_layers is not None and should_skip_layers:
1110
+ timestep = t.expand(latents.shape[0])
1111
+ latent_model_input = latents
1112
+ noise_pred_skip_layers = self.transformer(
1113
+ hidden_states=latent_model_input,
1114
+ timestep=timestep,
1115
+ encoder_hidden_states=original_prompt_embeds,
1116
+ pooled_projections=original_pooled_prompt_embeds,
1117
+ joint_attention_kwargs=self.joint_attention_kwargs,
1118
+ return_dict=False,
1119
+ skip_layers=skip_guidance_layers,
1120
+ )[0]
1121
+ noise_pred = (
1122
+ noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
1123
+ )
1124
+
1125
+ # compute the previous noisy sample x_t -> x_t-1
1126
+ latents_dtype = latents.dtype
1127
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1128
+ # latents = noise_pred
1129
+
1130
+ if latents.dtype != latents_dtype:
1131
+ if torch.backends.mps.is_available():
1132
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1133
+ latents = latents.to(latents_dtype)
1134
+
1135
+ if callback_on_step_end is not None:
1136
+ callback_kwargs = {}
1137
+ for k in callback_on_step_end_tensor_inputs:
1138
+ callback_kwargs[k] = locals()[k]
1139
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1140
+
1141
+ latents = callback_outputs.pop("latents", latents)
1142
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1143
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1144
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1145
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1146
+ )
1147
+
1148
+ # call the callback, if provided
1149
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1150
+ progress_bar.update()
1151
+
1152
+ if XLA_AVAILABLE:
1153
+ xm.mark_step()
1154
+
1155
+ if output_type == "latent":
1156
+ image = latents
1157
+
1158
+ else:
1159
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1160
+
1161
+ image = self.vae.decode(latents, return_dict=False)[0]
1162
+ image = self.image_processor.postprocess(image, output_type=output_type)
1163
+
1164
+ # Offload all models
1165
+ self.maybe_free_model_hooks()
1166
+
1167
+ if not return_dict:
1168
+ return (image,)
1169
+
1170
+ return StableDiffusion3PipelineOutput(images=image)
video_infer.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers.utils import export_to_video
3
+ from diffusers import AutoencoderKLWan#, WanPipeline
4
+ from wan_pipeline import WanPipeline
5
+ from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
6
+
7
+ # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
8
+ model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
9
+ vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
10
+ flow_shift = 8.0 # 5.0 for 720P, 3.0 for 480P
11
+ scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift)
12
+ pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
13
+ pipe.scheduler = scheduler
14
+ pipe.to("cuda")
15
+
16
+ prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
17
+ negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
18
+
19
+ output = pipe(
20
+ prompt=prompt,
21
+ negative_prompt=negative_prompt,
22
+ height=480,
23
+ width=832,
24
+ num_frames=81,
25
+ guidance_scale=6.0,
26
+ use_cfg_zero_star=True,
27
+ use_zero_init=True,
28
+ zero_steps=0
29
+ ).frames[0]
30
+ export_to_video(output, "output.mp4", fps=16)
wan_pipeline.py ADDED
@@ -0,0 +1,616 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import html
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import ftfy
19
+ import regex as re
20
+ import torch
21
+ from transformers import AutoTokenizer, UMT5EncoderModel
22
+
23
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
24
+ from diffusers.loaders import WanLoraLoaderMixin
25
+ from diffusers.models import AutoencoderKLWan, WanTransformer3DModel
26
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27
+ from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
28
+ from diffusers.utils.torch_utils import randn_tensor
29
+ from diffusers.video_processor import VideoProcessor
30
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
31
+ from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
32
+
33
+
34
+ if is_torch_xla_available():
35
+ import torch_xla.core.xla_model as xm
36
+
37
+ XLA_AVAILABLE = True
38
+ else:
39
+ XLA_AVAILABLE = False
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+
44
+ EXAMPLE_DOC_STRING = """
45
+ Examples:
46
+ ```python
47
+ >>> import torch
48
+ >>> from diffusers.utils import export_to_video
49
+ >>> from diffusers import AutoencoderKLWan, WanPipeline
50
+ >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
51
+
52
+ >>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
53
+ >>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
54
+ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
55
+ >>> pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
56
+ >>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P
57
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
58
+ >>> pipe.to("cuda")
59
+
60
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
61
+ >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
62
+
63
+ >>> output = pipe(
64
+ ... prompt=prompt,
65
+ ... negative_prompt=negative_prompt,
66
+ ... height=720,
67
+ ... width=1280,
68
+ ... num_frames=81,
69
+ ... guidance_scale=5.0,
70
+ ... ).frames[0]
71
+ >>> export_to_video(output, "output.mp4", fps=16)
72
+ ```
73
+ """
74
+
75
+ def optimized_scale(positive_flat, negative_flat):
76
+
77
+ # Calculate dot production
78
+ dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
79
+
80
+ # Squared norm of uncondition
81
+ squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
82
+
83
+ # st_star = v_cond^T * v_uncond / ||v_uncond||^2
84
+ st_star = dot_product / squared_norm
85
+
86
+ return st_star
87
+
88
+ def basic_clean(text):
89
+ text = ftfy.fix_text(text)
90
+ text = html.unescape(html.unescape(text))
91
+ return text.strip()
92
+
93
+
94
+ def whitespace_clean(text):
95
+ text = re.sub(r"\s+", " ", text)
96
+ text = text.strip()
97
+ return text
98
+
99
+
100
+ def prompt_clean(text):
101
+ text = whitespace_clean(basic_clean(text))
102
+ return text
103
+
104
+
105
+ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
106
+ r"""
107
+ Pipeline for text-to-video generation using Wan.
108
+
109
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
110
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
111
+
112
+ Args:
113
+ tokenizer ([`T5Tokenizer`]):
114
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
115
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
116
+ text_encoder ([`T5EncoderModel`]):
117
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
118
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
119
+ transformer ([`WanTransformer3DModel`]):
120
+ Conditional Transformer to denoise the input latents.
121
+ scheduler ([`UniPCMultistepScheduler`]):
122
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
123
+ vae ([`AutoencoderKLWan`]):
124
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
125
+ """
126
+
127
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
128
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
129
+
130
+ def __init__(
131
+ self,
132
+ tokenizer: AutoTokenizer,
133
+ text_encoder: UMT5EncoderModel,
134
+ transformer: WanTransformer3DModel,
135
+ vae: AutoencoderKLWan,
136
+ scheduler: FlowMatchEulerDiscreteScheduler,
137
+ ):
138
+ super().__init__()
139
+
140
+ self.register_modules(
141
+ vae=vae,
142
+ text_encoder=text_encoder,
143
+ tokenizer=tokenizer,
144
+ transformer=transformer,
145
+ scheduler=scheduler,
146
+ )
147
+
148
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
149
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
150
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
151
+
152
+ def _get_t5_prompt_embeds(
153
+ self,
154
+ prompt: Union[str, List[str]] = None,
155
+ num_videos_per_prompt: int = 1,
156
+ max_sequence_length: int = 226,
157
+ device: Optional[torch.device] = None,
158
+ dtype: Optional[torch.dtype] = None,
159
+ ):
160
+ device = device or self._execution_device
161
+ dtype = dtype or self.text_encoder.dtype
162
+
163
+ prompt = [prompt] if isinstance(prompt, str) else prompt
164
+ prompt = [prompt_clean(u) for u in prompt]
165
+ batch_size = len(prompt)
166
+
167
+ text_inputs = self.tokenizer(
168
+ prompt,
169
+ padding="max_length",
170
+ max_length=max_sequence_length,
171
+ truncation=True,
172
+ add_special_tokens=True,
173
+ return_attention_mask=True,
174
+ return_tensors="pt",
175
+ )
176
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
177
+ seq_lens = mask.gt(0).sum(dim=1).long()
178
+
179
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
180
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
181
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
182
+ prompt_embeds = torch.stack(
183
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
184
+ )
185
+
186
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
187
+ _, seq_len, _ = prompt_embeds.shape
188
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
189
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
190
+
191
+ return prompt_embeds
192
+
193
+ def encode_prompt(
194
+ self,
195
+ prompt: Union[str, List[str]],
196
+ negative_prompt: Optional[Union[str, List[str]]] = None,
197
+ do_classifier_free_guidance: bool = True,
198
+ num_videos_per_prompt: int = 1,
199
+ prompt_embeds: Optional[torch.Tensor] = None,
200
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
201
+ max_sequence_length: int = 226,
202
+ device: Optional[torch.device] = None,
203
+ dtype: Optional[torch.dtype] = None,
204
+ ):
205
+ r"""
206
+ Encodes the prompt into text encoder hidden states.
207
+
208
+ Args:
209
+ prompt (`str` or `List[str]`, *optional*):
210
+ prompt to be encoded
211
+ negative_prompt (`str` or `List[str]`, *optional*):
212
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
213
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
214
+ less than `1`).
215
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
216
+ Whether to use classifier free guidance or not.
217
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
218
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
219
+ prompt_embeds (`torch.Tensor`, *optional*):
220
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
221
+ provided, text embeddings will be generated from `prompt` input argument.
222
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
223
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
224
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
225
+ argument.
226
+ device: (`torch.device`, *optional*):
227
+ torch device
228
+ dtype: (`torch.dtype`, *optional*):
229
+ torch dtype
230
+ """
231
+ device = device or self._execution_device
232
+
233
+ prompt = [prompt] if isinstance(prompt, str) else prompt
234
+ if prompt is not None:
235
+ batch_size = len(prompt)
236
+ else:
237
+ batch_size = prompt_embeds.shape[0]
238
+
239
+ if prompt_embeds is None:
240
+ prompt_embeds = self._get_t5_prompt_embeds(
241
+ prompt=prompt,
242
+ num_videos_per_prompt=num_videos_per_prompt,
243
+ max_sequence_length=max_sequence_length,
244
+ device=device,
245
+ dtype=dtype,
246
+ )
247
+
248
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
249
+ negative_prompt = negative_prompt or ""
250
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
251
+
252
+ if prompt is not None and type(prompt) is not type(negative_prompt):
253
+ raise TypeError(
254
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
255
+ f" {type(prompt)}."
256
+ )
257
+ elif batch_size != len(negative_prompt):
258
+ raise ValueError(
259
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
260
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
261
+ " the batch size of `prompt`."
262
+ )
263
+
264
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
265
+ prompt=negative_prompt,
266
+ num_videos_per_prompt=num_videos_per_prompt,
267
+ max_sequence_length=max_sequence_length,
268
+ device=device,
269
+ dtype=dtype,
270
+ )
271
+
272
+ return prompt_embeds, negative_prompt_embeds
273
+
274
+ def check_inputs(
275
+ self,
276
+ prompt,
277
+ negative_prompt,
278
+ height,
279
+ width,
280
+ prompt_embeds=None,
281
+ negative_prompt_embeds=None,
282
+ callback_on_step_end_tensor_inputs=None,
283
+ ):
284
+ if height % 16 != 0 or width % 16 != 0:
285
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
286
+
287
+ if callback_on_step_end_tensor_inputs is not None and not all(
288
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
289
+ ):
290
+ raise ValueError(
291
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
292
+ )
293
+
294
+ if prompt is not None and prompt_embeds is not None:
295
+ raise ValueError(
296
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
297
+ " only forward one of the two."
298
+ )
299
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
300
+ raise ValueError(
301
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
302
+ " only forward one of the two."
303
+ )
304
+ elif prompt is None and prompt_embeds is None:
305
+ raise ValueError(
306
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
307
+ )
308
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
309
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
310
+ elif negative_prompt is not None and (
311
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
312
+ ):
313
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
314
+
315
+ def prepare_latents(
316
+ self,
317
+ batch_size: int,
318
+ num_channels_latents: int = 16,
319
+ height: int = 480,
320
+ width: int = 832,
321
+ num_frames: int = 81,
322
+ dtype: Optional[torch.dtype] = None,
323
+ device: Optional[torch.device] = None,
324
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
325
+ latents: Optional[torch.Tensor] = None,
326
+ ) -> torch.Tensor:
327
+ if latents is not None:
328
+ return latents.to(device=device, dtype=dtype)
329
+
330
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
331
+ shape = (
332
+ batch_size,
333
+ num_channels_latents,
334
+ num_latent_frames,
335
+ int(height) // self.vae_scale_factor_spatial,
336
+ int(width) // self.vae_scale_factor_spatial,
337
+ )
338
+ if isinstance(generator, list) and len(generator) != batch_size:
339
+ raise ValueError(
340
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
341
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
342
+ )
343
+
344
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
345
+ return latents
346
+
347
+ @property
348
+ def guidance_scale(self):
349
+ return self._guidance_scale
350
+
351
+ @property
352
+ def do_classifier_free_guidance(self):
353
+ return self._guidance_scale > 1.0
354
+
355
+ @property
356
+ def num_timesteps(self):
357
+ return self._num_timesteps
358
+
359
+ @property
360
+ def current_timestep(self):
361
+ return self._current_timestep
362
+
363
+ @property
364
+ def interrupt(self):
365
+ return self._interrupt
366
+
367
+ @property
368
+ def attention_kwargs(self):
369
+ return self._attention_kwargs
370
+
371
+ @torch.no_grad()
372
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
373
+ def __call__(
374
+ self,
375
+ prompt: Union[str, List[str]] = None,
376
+ negative_prompt: Union[str, List[str]] = None,
377
+ height: int = 480,
378
+ width: int = 832,
379
+ num_frames: int = 81,
380
+ num_inference_steps: int = 50,
381
+ guidance_scale: float = 5.0,
382
+ num_videos_per_prompt: Optional[int] = 1,
383
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
384
+ latents: Optional[torch.Tensor] = None,
385
+ prompt_embeds: Optional[torch.Tensor] = None,
386
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
387
+ output_type: Optional[str] = "np",
388
+ return_dict: bool = True,
389
+ attention_kwargs: Optional[Dict[str, Any]] = None,
390
+ callback_on_step_end: Optional[
391
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
392
+ ] = None,
393
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
394
+ max_sequence_length: int = 512,
395
+ use_cfg_zero_star: Optional[bool] = False,
396
+ use_zero_init: Optional[bool] = True,
397
+ zero_steps: Optional[int] = 0,
398
+ ):
399
+ r"""
400
+ The call function to the pipeline for generation.
401
+
402
+ Args:
403
+ prompt (`str` or `List[str]`, *optional*):
404
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
405
+ instead.
406
+ height (`int`, defaults to `480`):
407
+ The height in pixels of the generated image.
408
+ width (`int`, defaults to `832`):
409
+ The width in pixels of the generated image.
410
+ num_frames (`int`, defaults to `81`):
411
+ The number of frames in the generated video.
412
+ num_inference_steps (`int`, defaults to `50`):
413
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
414
+ expense of slower inference.
415
+ guidance_scale (`float`, defaults to `5.0`):
416
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
417
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
418
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
419
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
420
+ usually at the expense of lower image quality.
421
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
422
+ The number of images to generate per prompt.
423
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
424
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
425
+ generation deterministic.
426
+ latents (`torch.Tensor`, *optional*):
427
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
428
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
429
+ tensor is generated by sampling using the supplied random `generator`.
430
+ prompt_embeds (`torch.Tensor`, *optional*):
431
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
432
+ provided, text embeddings are generated from the `prompt` input argument.
433
+ output_type (`str`, *optional*, defaults to `"pil"`):
434
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
435
+ return_dict (`bool`, *optional*, defaults to `True`):
436
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
437
+ attention_kwargs (`dict`, *optional*):
438
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
439
+ `self.processor` in
440
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
441
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
442
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
443
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
444
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
445
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
446
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
447
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
448
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
449
+ `._callback_tensor_inputs` attribute of your pipeline class.
450
+ autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
451
+ The dtype to use for the torch.amp.autocast.
452
+
453
+ Examples:
454
+
455
+ Returns:
456
+ [`~WanPipelineOutput`] or `tuple`:
457
+ If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
458
+ the first element is a list with the generated images and the second element is a list of `bool`s
459
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
460
+ """
461
+
462
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
463
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
464
+
465
+ # 1. Check inputs. Raise error if not correct
466
+ self.check_inputs(
467
+ prompt,
468
+ negative_prompt,
469
+ height,
470
+ width,
471
+ prompt_embeds,
472
+ negative_prompt_embeds,
473
+ callback_on_step_end_tensor_inputs,
474
+ )
475
+
476
+ self._guidance_scale = guidance_scale
477
+ self._attention_kwargs = attention_kwargs
478
+ self._current_timestep = None
479
+ self._interrupt = False
480
+
481
+ device = self._execution_device
482
+
483
+ # 2. Define call parameters
484
+ if prompt is not None and isinstance(prompt, str):
485
+ batch_size = 1
486
+ elif prompt is not None and isinstance(prompt, list):
487
+ batch_size = len(prompt)
488
+ else:
489
+ batch_size = prompt_embeds.shape[0]
490
+
491
+ # 3. Encode input prompt
492
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
493
+ prompt=prompt,
494
+ negative_prompt=negative_prompt,
495
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
496
+ num_videos_per_prompt=num_videos_per_prompt,
497
+ prompt_embeds=prompt_embeds,
498
+ negative_prompt_embeds=negative_prompt_embeds,
499
+ max_sequence_length=max_sequence_length,
500
+ device=device,
501
+ )
502
+
503
+ transformer_dtype = self.transformer.dtype
504
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
505
+ if negative_prompt_embeds is not None:
506
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
507
+
508
+ # 4. Prepare timesteps
509
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
510
+ timesteps = self.scheduler.timesteps
511
+
512
+ # 5. Prepare latent variables
513
+ num_channels_latents = self.transformer.config.in_channels
514
+ latents = self.prepare_latents(
515
+ batch_size * num_videos_per_prompt,
516
+ num_channels_latents,
517
+ height,
518
+ width,
519
+ num_frames,
520
+ torch.float32,
521
+ device,
522
+ generator,
523
+ latents,
524
+ )
525
+
526
+ # 6. Denoising loop
527
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
528
+ self._num_timesteps = len(timesteps)
529
+
530
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
531
+ for i, t in enumerate(timesteps):
532
+ if self.interrupt:
533
+ continue
534
+
535
+ self._current_timestep = t
536
+ latent_model_input = latents.to(transformer_dtype)
537
+ timestep = t.expand(latents.shape[0])
538
+
539
+ noise_pred = self.transformer(
540
+ hidden_states=latent_model_input,
541
+ timestep=timestep,
542
+ encoder_hidden_states=prompt_embeds,
543
+ attention_kwargs=attention_kwargs,
544
+ return_dict=False,
545
+ )[0]
546
+
547
+ if self.do_classifier_free_guidance:
548
+ noise_pred_uncond = self.transformer(
549
+ hidden_states=latent_model_input,
550
+ timestep=timestep,
551
+ encoder_hidden_states=negative_prompt_embeds,
552
+ attention_kwargs=attention_kwargs,
553
+ return_dict=False,
554
+ )[0]
555
+
556
+ noise_pred_text = noise_pred
557
+ if use_cfg_zero_star:
558
+ positive_flat = noise_pred_text.view(batch_size, -1)
559
+ negative_flat = noise_pred_uncond.view(batch_size, -1)
560
+
561
+ alpha = optimized_scale(positive_flat,negative_flat)
562
+ alpha = alpha.view(batch_size, 1, 1, 1)
563
+
564
+ if (i <= zero_steps) and use_zero_init:
565
+ noise_pred = noise_pred_text*0.
566
+ else:
567
+ noise_pred = noise_pred_uncond * alpha + guidance_scale * (noise_pred_text - noise_pred_uncond * alpha)
568
+ else:
569
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
570
+
571
+
572
+ # compute the previous noisy sample x_t -> x_t-1
573
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
574
+
575
+ if callback_on_step_end is not None:
576
+ callback_kwargs = {}
577
+ for k in callback_on_step_end_tensor_inputs:
578
+ callback_kwargs[k] = locals()[k]
579
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
580
+
581
+ latents = callback_outputs.pop("latents", latents)
582
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
583
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
584
+
585
+ # call the callback, if provided
586
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
587
+ progress_bar.update()
588
+
589
+ if XLA_AVAILABLE:
590
+ xm.mark_step()
591
+
592
+ self._current_timestep = None
593
+
594
+ if not output_type == "latent":
595
+ latents = latents.to(self.vae.dtype)
596
+ latents_mean = (
597
+ torch.tensor(self.vae.config.latents_mean)
598
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
599
+ .to(latents.device, latents.dtype)
600
+ )
601
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
602
+ latents.device, latents.dtype
603
+ )
604
+ latents = latents / latents_std + latents_mean
605
+ video = self.vae.decode(latents, return_dict=False)[0]
606
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
607
+ else:
608
+ video = latents
609
+
610
+ # Offload all models
611
+ self.maybe_free_model_hooks()
612
+
613
+ if not return_dict:
614
+ return (video,)
615
+
616
+ return WanPipelineOutput(frames=video)