import re
from typing import Dict, List, Tuple, Union

import torch
import torch.nn as nn

from ...models.attention_processor import (
    Attention,
    AttentionProcessor,
    PAGCFGIdentitySelfAttnProcessor2_0,
    PAGIdentitySelfAttnProcessor2_0,
)
from ...utils import logging


logger = logging.get_logger(__name__)


class PAGMixin:
    r"""Mixin class for [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377v1)."""
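
    # How the pieces fit together (sketch; the exact call sites live in the individual
    # PAG-enabled pipelines):
    #   1. `set_pag_applied_layers(...)` records which self-attention layers to perturb and
    #      which pair of PAG attention processors to use.
    #   2. `_prepare_perturbed_attention_guidance(...)` builds the batched conditioning with an
    #      extra "perturbed" copy (plus the unconditional copy when CFG is enabled).
    #   3. `_set_pag_attn_processor(...)` swaps the PAG processors onto the matched layers.
    #   4. `_apply_perturbed_attention_guidance(...)` recombines the chunked noise predictions
    #      into the final guided prediction.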

    def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
        r"""
        Set the attention processor for the PAG layers.
        """
        pag_attn_processors = self._pag_attn_processors
        if pag_attn_processors is None:
            raise ValueError(
                "No PAG attention processors have been set. Set the attention processors by calling "
                "`set_pag_applied_layers` and passing the relevant parameters."
            )
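
        # pag_attn_processors[0] handles the CFG + PAG batch (uncond / cond / perturbed cond);
        # pag_attn_processors[1] handles PAG without CFG (cond / perturbed cond).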
        pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1]

        if hasattr(self, "unet"):
            model: nn.Module = self.unet
        else:
            model: nn.Module = self.transformer

        def is_self_attn(module: nn.Module) -> bool:
            r"""
            Check whether the module is a self-attention `Attention` module (i.e. not cross-attention).
            """
            return isinstance(module, Attention) and not module.is_cross_attention

        def is_fake_integral_match(layer_id, name):
            layer_id = layer_id.split(".")[-1]
            name = name.split(".")[-1]
            return layer_id.isnumeric() and name.isnumeric() and layer_id == name

        for layer_id in pag_applied_layers:
            target_modules = []

            for name, module in model.named_modules():
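                # Keep a module only when all of the following hold:
                #   (1) it is a self-attention `Attention` module,
                #   (2) its fully qualified name matches `layer_id` (treated as a regex), and
                #   (3) the match is not a "fake integral match": if both `layer_id` and the
                #       module name end in a bare integer, an exact match of those trailing
                #       integers alone is not treated as a valid hit.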
                if (
                    is_self_attn(module)
                    and re.search(layer_id, name) is not None
                    and not is_fake_integral_match(layer_id, name)
                ):
                    logger.debug(f"Applying PAG to layer: {name}")
                    target_modules.append(module)

            if len(target_modules) == 0:
                raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}")

            for module in target_modules:
                module.processor = pag_attn_proc

    def _get_pag_scale(self, t):
        r"""
        Get the scale factor for the perturbed attention guidance at timestep `t`.
        """
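        # Worked example of the adaptive schedule below (illustrative values, not defaults):
        # with pag_scale=3.0 and pag_adaptive_scale=0.002, t=900 gives 3.0 - 0.002 * 100 = 2.8,
        # while t=200 gives 3.0 - 0.002 * 800 = 1.4; negative values are clamped to 0, so the
        # guidance fades as denoising approaches t=0.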
        if self.do_pag_adaptive_scaling:
            signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t)
            if signal_scale < 0:
                signal_scale = 0
            return signal_scale
        else:
            return self.pag_scale

    def _apply_perturbed_attention_guidance(
        self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False
    ):
        r"""
        Apply perturbed attention guidance to the noise prediction.

        Args:
            noise_pred (torch.Tensor): The noise prediction tensor.
            do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
            guidance_scale (float): The scale factor for the guidance term.
            t (int): The current time step.
            return_pred_text (bool): Whether to also return the text noise prediction.

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after
            applying perturbed attention guidance, and the text noise prediction if `return_pred_text` is True.
        """
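        # In equation form, the update below is
        #   eps = eps_uncond + guidance_scale * (eps_text - eps_uncond) + pag_scale * (eps_text - eps_perturb)
        # when CFG is enabled, and eps = eps_text + pag_scale * (eps_text - eps_perturb) otherwise,
        # where eps_perturb is the prediction from the branch run with the PAG attention processors.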
        pag_scale = self._get_pag_scale(t)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
            noise_pred = (
                noise_pred_uncond
                + guidance_scale * (noise_pred_text - noise_pred_uncond)
                + pag_scale * (noise_pred_text - noise_pred_perturb)
            )
        else:
            noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
            noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
        if return_pred_text:
            return noise_pred, noise_pred_text
        return noise_pred

    def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
        """
        Prepare the batched conditioning input for perturbed attention guidance.

        Args:
            cond (torch.Tensor): The conditional input tensor.
            uncond (torch.Tensor): The unconditional input tensor.
            do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance.

        Returns:
            torch.Tensor: The batched tensor expected by the PAG attention processors.
        """
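        # Duplicate the conditional batch: the first copy goes through regular self-attention and the
        # second is the branch the PAG processors perturb. With CFG the unconditional batch is prepended,
        # giving the [uncond, cond, perturbed cond] layout that `_apply_perturbed_attention_guidance`
        # later splits with `chunk(3)` (or `chunk(2)` without CFG).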
        cond = torch.cat([cond] * 2, dim=0)

        if do_classifier_free_guidance:
            cond = torch.cat([uncond, cond], dim=0)
        return cond

    def set_pag_applied_layers(
        self,
        pag_applied_layers: Union[str, List[str]],
        pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = (
            PAGCFGIdentitySelfAttnProcessor2_0(),
            PAGIdentitySelfAttnProcessor2_0(),
        ),
    ):
        r"""
        Set the self-attention layers to apply PAG. Raises a ValueError if the input is invalid.

        Args:
            pag_applied_layers (`str` or `List[str]`):
                One or more strings identifying the layer names, or a simple regex for matching multiple layers,
                where PAG is to be applied. A few ways of expected usage are as follows:
                  - Single layers specified as - "blocks.{layer_index}"
                  - Multiple layers as a list - ["blocks.{layer_index_1}", "blocks.{layer_index_2}", ...]
                  - Multiple layers as a block name - "mid"
                  - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})"
            pag_attn_processors (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to
                `(PAGCFGIdentitySelfAttnProcessor2_0(), PAGIdentitySelfAttnProcessor2_0())`):
                A tuple of two attention processors. The first is used for PAG with classifier-free guidance
                enabled (unconditional, conditional, and perturbed branches); the second is used for PAG with
                CFG disabled (conditional and perturbed branches only).
        """
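        # Usage sketch (hypothetical pipeline object; valid layer names depend on the denoiser):
        #   pipe.set_pag_applied_layers("mid")                       # every self-attn layer whose name contains "mid"
        #   pipe.set_pag_applied_layers(["blocks.13", "blocks.14"])  # an explicit list of layers
        #   pipe.set_pag_applied_layers("blocks.(13|14)")            # the same two layers via a regex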
        if not hasattr(self, "_pag_attn_processors"):
            self._pag_attn_processors = None

        if not isinstance(pag_applied_layers, list):
            pag_applied_layers = [pag_applied_layers]
        if pag_attn_processors is not None:
            if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2:
                raise ValueError("Expected a tuple of two attention processors")

        for i in range(len(pag_applied_layers)):
            if not isinstance(pag_applied_layers[i], str):
                raise ValueError(
                    f"Expected either a string or a list of strings but got type {type(pag_applied_layers[i])}"
                )

        self.pag_applied_layers = pag_applied_layers
        self._pag_attn_processors = pag_attn_processors

    @property
    def pag_scale(self) -> float:
        r"""Get the scale factor for the perturbed attention guidance."""
        return self._pag_scale

    @property
    def pag_adaptive_scale(self) -> float:
        r"""Get the adaptive scale factor for the perturbed attention guidance."""
        return self._pag_adaptive_scale

    @property
    def do_pag_adaptive_scaling(self) -> bool:
        r"""Check if adaptive scaling is enabled for the perturbed attention guidance."""
        return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    def do_perturbed_attention_guidance(self) -> bool:
        r"""Check if perturbed attention guidance is enabled."""
        return self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    def pag_attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of PAG attention processors: A dictionary containing all PAG attention processors used in the
            model, keyed by the name of the layer.
        """
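        # Illustrative return value (layer names depend on the denoiser; this one is UNet-style):
        #   {"mid_block.attentions.0.transformer_blocks.0.attn1.processor": PAGCFGIdentitySelfAttnProcessor2_0()}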
        if self._pag_attn_processors is None:
            return {}

        valid_attn_processors = {x.__class__ for x in self._pag_attn_processors}

        processors = {}
        if hasattr(self, "unet"):
            denoiser_module = self.unet
        elif hasattr(self, "transformer"):
            denoiser_module = self.transformer
        else:
            raise ValueError("No denoiser module found.")

        for name, proc in denoiser_module.attn_processors.items():
            if proc.__class__ in valid_attn_processors:
                processors[name] = proc

        return processors