# Spaces: Running on Zero
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
from einops import rearrange, repeat | |
# import cv2 | |
# from basicsr.utils import img2tensor, tensor2img | |
import os | |
import math | |
import inspect | |
import numpy as np | |
from dataclasses import dataclass | |
from typing import Optional | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from diffusers.configuration_utils import ConfigMixin, register_to_config | |
from diffusers.models.modeling_utils import ModelMixin | |
from diffusers.utils import BaseOutput | |
from diffusers.utils.import_utils import is_xformers_available | |
from diffusers.models.attention_processor import Attention | |
from diffusers.models.attention import FeedForward, AdaLayerNorm | |
@dataclass
class Transformer2DModelOutput(BaseOutput):
    """Output container returned by ``Transformer2DModel.forward`` when ``return_dict=True``.

    Attributes:
        sample: Result of the transformer, rearranged back to
            ``(batch, channels, frames, height, width)``.
    """

    # @dataclass supplies the generated __init__ / field handling that diffusers'
    # BaseOutput expects of its subclasses (`dataclass` is imported for this).
    sample: torch.FloatTensor
if is_xformers_available(): | |
import xformers | |
import xformers.ops | |
else: | |
xformers = None | |
class Transformer2DModel(ModelMixin, ConfigMixin):
    """Spatial transformer applied frame-by-frame to a 5-D feature map.

    ``forward`` folds ``(batch, channels, frames, height, width)`` into
    ``(batch * frames, channels, height, width)``, applies GroupNorm, projects to
    ``num_attention_heads * attention_head_dim`` channels, runs ``num_layers``
    ``BasicTransformerBlock`` instances over the flattened spatial tokens,
    projects back to ``in_channels`` and adds the input as a residual.

    Only continuous input (``in_channels`` set) is implemented; vectorized input
    (``num_vector_embeds`` set) raises ``NotImplementedError``.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        use_sc_attn: bool = False,
        use_st_attn: bool = False,
        updown="mid",
        layer_id=0,
    ):
        """Build the input projection, the transformer blocks and the output projection.

        Args:
            num_attention_heads: Heads per attention layer.
            attention_head_dim: Channels per head; ``inner_dim`` is heads * head_dim.
            in_channels: Channels of the continuous input (required here).
            num_layers: Number of ``BasicTransformerBlock`` instances.
            dropout, cross_attention_dim, attention_bias, activation_fn,
                num_embeds_ada_norm, only_cross_attention, upcast_attention,
                use_sc_attn, updown, layer_id: Forwarded to every block.
            norm_num_groups: Groups of the input GroupNorm.
            sample_size: Accepted for config compatibility; not read here.
            num_vector_embeds: Vectorized-input size; unsupported (see Raises).
            use_linear_projection: Use ``nn.Linear`` instead of 1x1 ``nn.Conv2d``
                for ``proj_in`` / ``proj_out``.
            use_st_attn: Accepted but not forwarded (blocks always get False).

        Raises:
            ValueError: If both or neither of ``in_channels`` and
                ``num_vector_embeds`` are given.
            NotImplementedError: For vectorized (``num_vector_embeds``) input.
        """
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Exactly one input kind must be configured: continuous images
        # (in_channels) or quantized image embeddings (num_vector_embeds).
        self.is_input_continuous = in_channels is not None
        self.is_input_vectorized = num_vector_embeds is not None
        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif not self.is_input_continuous and not self.is_input_vectorized:
            raise ValueError(
                f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is not None."
            )

        # 2. Input layers (continuous path only).
        if self.is_input_continuous:
            self.in_channels = in_channels
            self.norm = torch.nn.GroupNorm(
                num_groups=norm_num_groups,
                num_channels=in_channels,
                eps=1e-6,
                affine=True,
            )
            if use_linear_projection:
                self.proj_in = nn.Linear(in_channels, inner_dim)
            else:
                self.proj_in = nn.Conv2d(
                    in_channels, inner_dim, kernel_size=1, stride=1, padding=0
                )
        else:
            raise NotImplementedError

        # 3. Transformer blocks.
        # NOTE(review): `use_st_attn` is accepted above but every block is built
        # with use_st_attn=False; confirm this override is intentional.
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    use_sc_attn=use_sc_attn,
                    use_st_attn=False,
                    updown=updown,
                    layer_id=layer_id,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output layers: map inner_dim back to in_channels so the residual add works.
        if use_linear_projection:
            # Mirror of proj_in (inner_dim -> in_channels). The previous
            # Linear(in_channels, inner_dim) only worked when the two were equal.
            self.proj_out = nn.Linear(inner_dim, in_channels)
        else:
            self.proj_out = nn.Conv2d(
                inner_dim, in_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
        return_dict: bool = True,
        iter_cur=0,
        save_kv=True,
        mode="drag",
        mask=None,
    ):
        """Run the transformer over every frame and add the input back as a residual.

        Args:
            hidden_states: Tensor of shape ``(batch, channels, frames, height, width)``.
            encoder_hidden_states: Optional ``(batch, tokens, dim)`` conditioning;
                repeated once per frame.
            encoder_attention_mask: Optional mask for ``encoder_hidden_states``. A 2-D
                keep-mask (1 = keep, 0 = discard) becomes an additive bias of shape
                ``(batch, 1, tokens)``. A mask whose batch equals ``batch`` is repeated
                per frame to match the repeated encoder states.
            timestep: Forwarded to every block (consumed by AdaLayerNorm if enabled).
            return_dict: Return ``Transformer2DModelOutput`` if True, else ``(output,)``.
            iter_cur, save_kv, mode, mask: Forwarded unchanged to every block; their
                meaning is defined by the attention processors in use.

        Returns:
            ``Transformer2DModelOutput(sample=output)`` or ``(output,)``, where
            ``output`` has the same shape as ``hidden_states``.
        """
        # Convert a 2-D keep-mask into an additive bias, as diffusers does for attention_mask.
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # Input
        assert (
            hidden_states.dim() == 5
        ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_batch = hidden_states.shape[0]
        video_length = hidden_states.shape[2]
        # Fold frames into the batch axis so each frame is processed as an image.
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
        if encoder_hidden_states is not None:
            encoder_hidden_states = repeat(
                encoder_hidden_states, "b n c -> (b f) n c", f=video_length
            )
        if (
            encoder_attention_mask is not None
            and video_length > 1
            and encoder_attention_mask.shape[0] == video_batch
        ):
            # Keep the mask batch aligned with the per-frame encoder states; the
            # "(b f)" layout (b outer, f inner) matches repeat_interleave on dim 0.
            encoder_attention_mask = encoder_attention_mask.repeat_interleave(
                video_length, dim=0
            )

        batch, channel, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            # 1x1 conv in NCHW, then flatten spatial positions into tokens.
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
                batch, height * width, inner_dim
            )
        else:
            # Flatten to tokens first, then project each token with nn.Linear.
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
                batch, height * width, channel
            )
            hidden_states = self.proj_in(hidden_states)

        # Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                video_length=video_length,
                iter_cur=iter_cur,
                save_kv=save_kv,
                mode=mode,
                mask=mask,
            )

        # Output: restore NCHW and project back to the input channel count.
        if not self.use_linear_projection:
            hidden_states = (
                hidden_states.reshape(batch, height, width, inner_dim)
                .permute(0, 3, 1, 2)
                .contiguous()
            )
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = (
                hidden_states.reshape(batch, height, width, channel)
                .permute(0, 3, 1, 2)
                .contiguous()
            )

        output = hidden_states + residual
        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention, feed-forward.

    Each sub-layer is wrapped in a residual connection. ``norm1``/``norm2`` are
    ``AdaLayerNorm`` (conditioned on ``timestep``) when ``num_embeds_ada_norm`` is
    set, otherwise ``nn.LayerNorm``; ``norm3`` is always ``nn.LayerNorm``.
    Extra per-call arguments are forwarded to an attention layer only when its
    processor's ``__call__`` signature declares them (see ``get_attn_args``).
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        use_sc_attn: bool = False,
        use_st_attn: bool = False,
        updown="mid",
        layer_id=0,
    ):
        """Create the attention, feed-forward and normalization sub-layers.

        Args:
            dim: Token channel size (query dim of both attention layers).
            num_attention_heads, attention_head_dim: Passed to ``Attention``.
            dropout: Passed to both attention layers and ``FeedForward``.
            cross_attention_dim: Creates ``attn2``/``norm2`` when not None.
            activation_fn: Activation used by ``FeedForward``.
            num_embeds_ada_norm: Switches ``norm1``/``norm2`` to ``AdaLayerNorm``.
            attention_bias, upcast_attention: Passed to the attention layers.
            only_cross_attention: ``attn1`` is built with ``cross_attention_dim`` and
                fed ``encoder_hidden_states`` in ``forward``.
            use_sc_attn, use_st_attn: Stored as attributes; not read in this class.
            updown, layer_id: Attached to ``attn1`` as attributes, presumably so a
                custom processor can identify the layer (confirm with the processor).
        """
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.use_sc_attn = use_sc_attn
        self.use_st_attn = use_st_attn

        # 1. Self-attention (attends to encoder states when only_cross_attention).
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )
        self.attn1.updown = updown
        self.attn1.layer_id = layer_id

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)

        # 2. Cross-attention (self-attention inside Attention if encoder states are None).
        if cross_attention_dim is not None:
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        else:
            self.attn2 = None

        self.norm1 = (
            AdaLayerNorm(dim, num_embeds_ada_norm)
            if self.use_ada_layer_norm
            else nn.LayerNorm(dim)
        )
        if cross_attention_dim is not None:
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim)
            )
        else:
            self.norm2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim)

    def get_attn_args(self, attn_layer: nn.Module, attn_kwargs: dict):
        """Return the subset of ``attn_kwargs`` accepted by ``attn_layer.processor.__call__``.

        Keys the processor's signature does not declare are dropped and reported with
        ``warnings.warn``. Under the default warning filters each distinct message is
        shown once, instead of being printed on every forward pass of every block.
        """
        import warnings

        accepted = inspect.signature(attn_layer.processor.__call__).parameters
        unused_kwargs = [k for k in attn_kwargs if k not in accepted]
        if unused_kwargs:
            warnings.warn(
                f"Attention kwargs {unused_kwargs} are not expected by "
                f"{attn_layer.processor.__class__.__name__} and will be ignored.",
                stacklevel=2,
            )
        return {k: v for k, v in attn_kwargs.items() if k in accepted}

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
        attention_mask=None,
        video_length=None,
        iter_cur=0,
        save_kv=True,
        mode="drag",
        mask=None,
    ):
        """Apply self-attention, optional cross-attention and feed-forward with residuals.

        Args:
            hidden_states: Token features whose last dimension is ``dim``.
            encoder_hidden_states: Conditioning for ``attn2`` (and for ``attn1`` when
                ``only_cross_attention``).
            encoder_attention_mask: Passed as ``attention_mask`` to ``attn2``.
            timestep: Conditioning for ``AdaLayerNorm`` when enabled.
            attention_mask: Passed to ``attn1``.
            video_length, iter_cur, save_kv, mode, mask: Offered to ``attn1``'s
                processor; only ``iter_cur`` (or -1 while ``save_kv``) is offered to
                ``attn2``'s. Semantics are defined by the processors.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        # 1. Self-attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep)
            if self.use_ada_layer_norm
            else self.norm1(hidden_states)
        )
        attn1_kwargs = self.get_attn_args(
            self.attn1,
            {
                "video_length": video_length,
                "iter_cur": iter_cur,
                "save_kv": save_kv,
                "mode": mode,
                "mask": mask,
            },
        )
        hidden_states = (
            self.attn1(
                norm_hidden_states,
                # attn1 was built with cross_attention_dim only in this mode, so it
                # must receive the encoder states then (as upstream diffusers does).
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                **attn1_kwargs,
            )
            + hidden_states
        )

        # 2. Cross-attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep)
                if self.use_ada_layer_norm
                else self.norm2(hidden_states)
            )
            # While save_kv is set, iter_cur is replaced by -1, presumably a sentinel
            # understood by the cross-attention processor; confirm there.
            attn2_kwargs = self.get_attn_args(
                self.attn2, {"iter_cur": -1 if save_kv else iter_cur}
            )
            hidden_states = (
                self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **attn2_kwargs,
                )
                + hidden_states
            )

        # 3. Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
        return hidden_states