Torch 2.0 support in Diffusers
Starting from version 0.13.0, Diffusers supports the latest optimizations from the upcoming PyTorch 2.0 release. These include:
- Support for native flash and memory-efficient attention without any extra dependencies.
- `torch.compile` support for compiling individual models for an extra performance boost.
Installation
To benefit from the native efficient attention and `torch.compile`, we will need to install the nightly version of PyTorch, as the stable version is yet to be released. The first step is to install CUDA 11.7 or CUDA 11.8, as PyTorch 2.0 does not support earlier versions. Once CUDA is installed, the PyTorch nightly can be installed using:

```bash
pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu117
```
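To confirm that the nightly build is picked up and that the new attention function is available, a quick check such as the following can be used (a minimal sketch; the exact version string will differ depending on the nightly you installed):

```python
import torch

# A PyTorch 2.0 nightly reports a version string such as "2.0.0.dev2023...+cu117"
print(torch.__version__)

# The native efficient attention entry point should be available
print(hasattr(torch.nn.functional, "scaled_dot_product_attention"))
```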
Using efficient attention and torch.compile
Efficient Attention
Efficient attention is implemented via the `torch.nn.functional.scaled_dot_product_attention` function, which automatically enables flash/memory-efficient attention, depending on the input and the GPU type. This is the same as the `memory_efficient_attention` from xFormers, but built natively into PyTorch.

Efficient attention will be enabled by default in Diffusers if PyTorch 2.0 is installed and if `torch.nn.functional.scaled_dot_product_attention` is available. To use it, install PyTorch 2.0 as suggested above and simply use the pipeline. For example:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
```
If you want to enable it explicitly (which is not required), you can do so as shown below.
```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipe.unet.set_attn_processor(AttnProcessor2_0())

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
```
This should be as fast and memory efficient as xFormers.
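If you want to compare against xFormers on your own hardware, you can still switch the same pipeline to xFormers attention. The sketch below assumes the `xformers` package is installed separately; it is not needed for the native efficient attention:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")

# Switch this pipeline to xFormers memory-efficient attention for a side-by-side comparison
pipe.enable_xformers_memory_efficient_attention()

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
```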
torch.compile
To get an additional speedup, we can use the new `torch.compile` feature. To do so, we wrap our `unet` with `torch.compile`. For more information and different options, refer to the torch compile docs.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipe.unet = torch.compile(pipe.unet)

batch_size = 10
steps = 50  # number of denoising steps

prompt = "A photo of an astronaut riding a horse on mars."
images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
```
Depending on the type of GPU, this can give between a 2% and 9% speed-up over efficient attention alone. But note that, as of now, the speed-up is mostly noticeable on the more recent GPU architectures, such as the A100.
Note that compilation will also take some time to complete, so it is best suited for situations where you need to prepare your pipeline once and then perform the same type of inference operations multiple times.
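As a rough illustration, the first call pays the compilation cost and subsequent calls run at the compiled speed. The following is a minimal timing sketch (not the script used for the benchmark below); the number of runs and the prompt are arbitrary:

```python
import time

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipe.unet = torch.compile(pipe.unet)

prompt = "a photo of an astronaut riding a horse on mars"

# The first call triggers compilation and is noticeably slower
pipe(prompt, num_inference_steps=50)

# Later calls reuse the compiled model
n_runs = 3
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n_runs):
    pipe(prompt, num_inference_steps=50)
torch.cuda.synchronize()
print(f"average time per run: {(time.perf_counter() - start) / n_runs:.2f} s")
```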
Benchmark
We conducted a simple benchmark on different GPUs to compare vanilla attention, xFormers, `torch.nn.functional.scaled_dot_product_attention`, and `torch.compile` + `torch.nn.functional.scaled_dot_product_attention`. For the benchmark we used the stable-diffusion-v1-4 model with 50 steps. The xFormers benchmark was run with `torch==1.13.1`. The tables below summarize the results we got. The Speed over xformers (%) column denotes the speed-up gained over xFormers using `torch.compile` + `torch.nn.functional.scaled_dot_product_attention`.
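The percentage in that column appears to be the relative reduction with respect to the xFormers time, which the raw numbers bear out. For example, for the A100 at batch size 10 in the fp16 table:

```python
# Timings (in seconds) taken from the fp16 table: A100, batch size 10
t_xformers = 8.70       # xFormers
t_sdpa_compile = 7.89   # SDPA + torch.compile

speedup = (t_xformers - t_sdpa_compile) / t_xformers * 100
print(f"{speedup:.2f}%")  # ~9.31%
```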
FP16 benchmark
The table below shows the benchmark results for inference using fp16. As we can see, `torch.nn.functional.scaled_dot_product_attention` is as fast as xFormers (sometimes slightly faster/slower) on all the GPUs we tested, and using `torch.compile` gives a further speed-up of up to 10% over xFormers, which is mostly noticeable on the A100 GPU. The time reported is in seconds.
GPU | Batch Size | Vanilla Attention | xFormers | PyTorch2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) |
---|---|---|---|---|---|---|
A100 | 10 | 12.02 | 8.7 | 8.79 | 7.89 | 9.31 |
A100 | 16 | 18.95 | 13.57 | 13.67 | 12.25 | 9.73 |
A100 | 32 (1) | OOM | 26.56 | 26.68 | 24.08 | 9.34 |
A100 | 64 (2) | | 52.51 | 53.03 | 47.81 | 8.95 |
A10 | 4 | 13.94 | 9.81 | 10.01 | 9.35 | 4.69 |
A10 | 8 | 27.09 | 19 | 19.53 | 18.33 | 3.53 |
A10 | 10 | 33.69 | 23.53 | 24.19 | 22.52 | 4.29 |
A10 | 16 | OOM | 37.55 | 38.31 | 36.81 | 1.97 |
A10 | 32 (1) | | 77.19 | 78.43 | 76.64 | 0.71 |
A10 | 64 (1) | | 173.59 | 158.99 | 155.14 | 10.63 |
T4 | 4 | 38.81 | 30.09 | 29.74 | 27.55 | 8.44 |
T4 | 8 | OOM | 55.71 | 55.99 | 53.85 | 3.34 |
T4 | 10 | OOM | 68.96 | 69.86 | 65.35 | 5.23 |
T4 | 16 | OOM | 111.47 | 113.26 | 106.93 | 4.07 |
V100 | 4 | 9.84 | 8.16 | 8.09 | 7.65 | 6.25 |
V100 | 8 | OOM | 15.62 | 15.44 | 14.59 | 6.59 |
V100 | 10 | OOM | 19.52 | 19.28 | 18.18 | 6.86 |
V100 | 16 | OOM | 30.29 | 29.84 | 28.22 | 6.83 |
3090 | 4 | 10.04 | 7.82 | 7.89 | 7.47 | 4.48 |
3090 | 8 | 19.27 | 14.97 | 15.04 | 14.22 | 5.01 |
3090 | 10 | 24.08 | 18.7 | 18.7 | 17.69 | 5.40 |
3090 | 16 | OOM | 29.06 | 29.06 | 28.2 | 2.96 |
3090 | 32 (1) | | 58.05 | 58 | 54.88 | 5.46 |
3090 | 64 (1) | | 126.54 | 126.03 | 117.33 | 7.28 |
3090 Ti | 4 | 9.07 | 7.14 | 7.15 | 6.81 | 4.62 |
3090 Ti | 8 | 17.51 | 13.65 | 13.72 | 12.99 | 4.84 |
3090 Ti | 10 (2) | 21.79 | 16.85 | 16.93 | 16.02 | 4.93 |
3090 Ti | 16 | OOM | 26.1 | 26.28 | 25.46 | 2.45 |
3090 Ti | 32 (1) | | 51.78 | 52.04 | 49.15 | 5.08 |
3090 Ti | 64 (1) | | 112.02 | 112.33 | 103.91 | 7.24 |
FP32 benchmark
The table below shows the benchmark results for inference using fp32. As we can see, `torch.nn.functional.scaled_dot_product_attention` is as fast as xFormers (sometimes slightly faster/slower) on all the GPUs we tested. Using `torch.compile` with efficient attention gives up to 18% performance improvement over xFormers on Ampere cards, and up to 20% over vanilla attention.
GPU | Batch Size | Vanilla Attention | xFormers | PyTorch2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) | Speed over vanilla (%) |
---|---|---|---|---|---|---|---|
A100 | 4 | 16.56 | 12.42 | 12.2 | 11.84 | 4.67 | 28.50 |
A100 | 10 | OOM | 29.93 | 29.44 | 28.5 | 4.78 | |
A100 | 16 | | 47.08 | 46.27 | 44.8 | 4.84 |
A100 | 32 | | 92.89 | 91.34 | 88.35 | 4.89 |
A100 | 64 | | 185.3 | 182.71 | 176.48 | 4.76 |
A10 | 1 | 10.59 | 8.81 | 7.51 | 7.35 | 16.57 | 30.59 |
A10 | 4 | 34.77 | 27.63 | 22.77 | 22.07 | 20.12 | 36.53 |
A10 | 8 | | 56.19 | 43.53 | 43.86 | 21.94 |
A10 | 16 | | 116.49 | 88.56 | 86.64 | 25.62 |
A10 | 32 | | 221.95 | 175.74 | 168.18 | 24.23 |
A10 | 48 | | 333.23 | | 264.84 | 20.52 |
T4 | 1 | 28.2 | 24.49 | 23.93 | 23.56 | 3.80 | 16.45 |
T4 | 2 | 52.77 | 45.7 | 45.88 | 45.06 | 1.40 | 14.61 |
T4 | 4 | OOM | 85.72 | 85.78 | 84.48 | 1.45 | |
T4 | 8 | | 149.64 | 150.75 | 148.4 | 0.83 |
V100 | 1 | 7.4 | 6.84 | 6.8 | 6.66 | 2.63 | 10.00 |
V100 | 2 | 13.85 | 12.81 | 12.66 | 12.35 | 3.59 | 10.83 |
V100 | 4 | OOM | 25.73 | 25.31 | 24.78 | 3.69 | |
V100 | 8 | | 43.95 | 43.37 | 42.25 | 3.87 |
V100 | 16 | | 84.99 | 84.73 | 82.55 | 2.87 |
3090 | 1 | 7.09 | 6.78 | 6.11 | 6.03 | 11.06 | 14.95 |
3090 | 4 | 22.69 | 21.45 | 18.67 | 18.09 | 15.66 | 20.27 |
3090 | 8 (2) | | 42.59 | 36.75 | 35.59 | 16.44 |
3090 | 16 | | 85.35 | 72.37 | 70.25 | 17.69 |
3090 | 32 (1) | | 162.05 | 138.99 | 134.53 | 16.98 |
3090 | 48 | | 241.91 | | 207.75 | 14.12 |
3090 Ti | 1 | 6.45 | 6.19 | 5.64 | 5.49 | 11.31 | 14.88 |
3090 Ti | 4 | 20.32 | 19.31 | 16.9 | 16.37 | 15.23 | 19.44 |
3090 Ti | 8 (2) | | 37.93 | 33.05 | 31.99 | 15.66 |
3090 Ti | 16 | | 75.37 | 65.25 | 64.32 | 14.66 |
3090 Ti | 32 (1) | | 142.55 | 124.44 | 120.74 | 15.30 |
3090 Ti | 48 | | 213.19 | | 186.55 | 12.50 |
4090 | 1 | 5.54 | 4.99 | 4.51 | |||
4090 | 4 | 13.67 | 11.4 | 10.3 | |||
4090 | 8 (2) | 19.79 | 17.13 | ||||
4090 | 16 | 38.62 | 33.14 | ||||
4090 | 32 (1) | 76.57 | 65.96 | ||||
4090 | 48 | 114.44 | 98.78 |
(1) Batch size >= 32 requires enable_vae_slicing() because of https://github.com/pytorch/pytorch/issues/81665. This is required for PyTorch 1.13.1, and also for PyTorch 2.0 with a batch size of 64.
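To reproduce the large-batch rows marked with (1), VAE slicing can be enabled on the pipeline before generation. The snippet below is a minimal sketch; the batch size and prompt are placeholders:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")

# Decode latents one image at a time so the VAE does not run out of memory at large batch sizes
pipe.enable_vae_slicing()

prompt = "a photo of an astronaut riding a horse on mars"
images = pipe(prompt, num_images_per_prompt=32).images
```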
For more details about how this benchmark was run, please refer to this PR.