# PSHuman/mvdiffusion/models_unclip/attn_processors.py
from typing import Any, Dict, Optional
import torch
from torch import nn
from diffusers.models.attention import Attention
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange, repeat
import math
import torch.nn.functional as F
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
class RowwiseMVAttention(Attention):
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
):
processor = XFormersMVAttnProcessor()
self.set_processor(processor)
# print("using xformers attention processor")
class IPCDAttention(Attention):
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
):
processor = XFormersIPCDAttnProcessor()
self.set_processor(processor)
# print("using xformers attention processor")
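
# Usage sketch (illustrative only; the layer sizes and view count below are
# assumptions, not values from this repository's configs). These thin Attention
# subclasses exist so that enabling xformers on the model swaps in the custom
# processors defined below instead of diffusers' default processor:
#
#     attn = RowwiseMVAttention(query_dim=320, heads=8, dim_head=40)
#     attn.set_use_memory_efficient_attention_xformers(True)  # -> XFormersMVAttnProcessor
#     x = torch.randn(2 * 6, 32 * 32, 320)                    # (b*v, h*w, c)
#     out = attn(x, num_views=6)  # extra kwargs are forwarded to the processor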
class XFormersMVAttnProcessor:
    r"""
    Row-wise multi-view attention processor backed by xformers memory-efficient
    attention. Tokens are regrouped so that each image row attends across all
    views, and, when `cd_attention_mid` is set, across the normal/color domains
    as well.
    """
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
num_views=1,
multiview_attention=True,
cd_attention_mid=False
):
# print(num_views)
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
height = int(math.sqrt(sequence_length))
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# from yuancheng; here attention_mask is None
if attention_mask is not None:
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)
if attn.group_norm is not None:
            print('Warning: group_norm is enabled; verify it is compatible with row-wise attention')
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key_raw = attn.to_k(encoder_hidden_states)
value_raw = attn.to_v(encoder_hidden_states)
# print('query', query.shape, 'key', key.shape, 'value', value.shape)
# pdb.set_trace()
        def transpose(tensor):
            # Concatenate the normal and color halves of the batch along the
            # width axis, then flatten so each image row attends across all
            # views of both domains.
            tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
            tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2)  # b v h w c
            tensor = torch.cat([tensor_0, tensor_1], dim=3)  # b v h 2w c
            tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height)
            return tensor
# print(mvcd_attention)
# import pdb;pdb.set_trace()
if cd_attention_mid:
key = transpose(key_raw)
value = transpose(value_raw)
query = transpose(query)
else:
key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
        query = attn.head_to_batch_dim(query).contiguous()  # torch.Size([960, 384, 64])
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if cd_attention_mid:
hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c
hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c
hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height)
else:
hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
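
# Shape sketch for the row-wise rearrange used above (the numbers are
# illustrative assumptions, not repo defaults). With b=2, v=6 views, a 32x32
# latent and c=320 channels:
#
#     x = torch.randn(2 * 6, 32 * 32, 320)                            # (b v, h w, c)
#     x = rearrange(x, "(b v) (h w) c -> (b h) (v w) c", v=6, h=32)
#     assert x.shape == (2 * 32, 6 * 32, 320)
#
# i.e. each query in image row h attends only to that same row in every view,
# rather than to all v*h*w tokens, which is what makes the attention "row-wise".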
class XFormersIPCDAttnProcessor:
    r"""
    Joint cross-domain (normal + color) attention processor backed by xformers
    memory-efficient attention. After attention, features of the dedicated face
    view are downsampled and blended into the head region of the frontal view.
    """
def process(self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
num_tasks=2,
num_views=6):
### TODO: num_views
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
height = int(math.sqrt(sequence_length))
height_st = height // 3
height_end = height - height_st
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# from yuancheng; here attention_mask is None
if attention_mask is not None:
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
assert num_tasks == 2 # only support two tasks now
# ip attn
# hidden_states = rearrange(hidden_states, '(b v) l c -> b v l c', v=num_views)
# body_hidden_states, face_hidden_states = rearrange(hidden_states[:, :-1, :, :], 'b v l c -> (b v) l c'), hidden_states[:, -1, :, :]
# print(body_hidden_states.shape, face_hidden_states.shape)
# import pdb;pdb.set_trace()
# hidden_states = body_hidden_states + attn.ip_scale * repeat(head_hidden_states.detach(), 'b l c -> (b v) l c', v=n_view)
# hidden_states = rearrange(
# torch.cat([rearrange(hidden_states, '(b v) l c -> b v l c'), head_hidden_states.unsqueeze(1)], dim=1),
# 'b v l c -> (b v) l c')
# face cross attention
# ip_hidden_states = repeat(face_hidden_states.detach(), 'b l c -> (b v) l c', v=num_views-1)
# ip_key = attn.to_k_ip(ip_hidden_states)
# ip_value = attn.to_v_ip(ip_hidden_states)
# ip_key = attn.head_to_batch_dim(ip_key).contiguous()
# ip_value = attn.head_to_batch_dim(ip_value).contiguous()
# ip_query = attn.head_to_batch_dim(body_hidden_states).contiguous()
# ip_hidden_states = xformers.ops.memory_efficient_attention(ip_query, ip_key, ip_value, attn_bias=attention_mask)
# ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
# ip_hidden_states = attn.to_out_ip[0](ip_hidden_states)
# ip_hidden_states = attn.to_out_ip[1](ip_hidden_states)
# import pdb;pdb.set_trace()
def transpose(tensor):
tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # bv hw c
tensor = torch.cat([tensor_0, tensor_1], dim=1) # bv 2hw c
# tensor = rearrange(tensor, "(b v) l c -> b v l c", v=num_views+1)
# body, face = tensor[:, :-1, :], tensor[:, -1:, :] # b,v,l,c; b,1,l,c
# face = face.repeat(1, num_views, 1, 1) # b,v,l,c
# tensor = torch.cat([body, face], dim=2) # b, v, 4hw, c
# tensor = rearrange(tensor, "b v l c -> (b v) l c")
return tensor
key = transpose(key)
value = transpose(value)
query = transpose(query)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states_normal, hidden_states_color = torch.chunk(hidden_states, dim=1, chunks=2) # bv, hw, c
hidden_states_normal = rearrange(hidden_states_normal, "(b v) (h w) c -> b v h w c", v=num_views+1, h=height)
face_normal = rearrange(hidden_states_normal[:, -1, :, :, :], 'b h w c -> b c h w').detach()
face_normal = rearrange(F.interpolate(face_normal, size=(height_st, height_st), mode='bilinear'), 'b c h w -> b h w c')
hidden_states_normal = hidden_states_normal.clone() # Create a copy of hidden_states_normal
hidden_states_normal[:, 0, :height_st, height_st:height_end, :] = 0.5 * hidden_states_normal[:, 0, :height_st, height_st:height_end, :] + 0.5 * face_normal
# hidden_states_normal[:, 0, :height_st, height_st:height_end, :] = 0.1 * hidden_states_normal[:, 0, :height_st, height_st:height_end, :] + 0.9 * face_normal
hidden_states_normal = rearrange(hidden_states_normal, "b v h w c -> (b v) (h w) c")
hidden_states_color = rearrange(hidden_states_color, "(b v) (h w) c -> b v h w c", v=num_views+1, h=height)
face_color = rearrange(hidden_states_color[:, -1, :, :, :], 'b h w c -> b c h w').detach()
face_color = rearrange(F.interpolate(face_color, size=(height_st, height_st), mode='bilinear'), 'b c h w -> b h w c')
hidden_states_color = hidden_states_color.clone() # Create a copy of hidden_states_color
hidden_states_color[:, 0, :height_st, height_st:height_end, :] = 0.5 * hidden_states_color[:, 0, :height_st, height_st:height_end, :] + 0.5 * face_color
# hidden_states_color[:, 0, :height_st, height_st:height_end, :] = 0.1 * hidden_states_color[:, 0, :height_st, height_st:height_end, :] + 0.9 * face_color
hidden_states_color = rearrange(hidden_states_color, "b v h w c -> (b v) (h w) c")
hidden_states = torch.cat([hidden_states_normal, hidden_states_color], dim=0) # 2bv hw c
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
num_tasks=2,
):
hidden_states = self.process(attn, hidden_states, encoder_hidden_states, attention_mask, temb, num_tasks)
# hidden_states = rearrange(hidden_states, '(b v) l c -> b v l c')
# body_hidden_states, head_hidden_states = rearrange(hidden_states[:, :-1, :, :], 'b v l c -> (b v) l c'), hidden_states[:, -1:, :, :]
# import pdb;pdb.set_trace()
# hidden_states = body_hidden_states + attn.ip_scale * head_hidden_states.detach().repeat(1, views, 1, 1)
# hidden_states = rearrange(
# torch.cat([rearrange(hidden_states, '(b v) l c -> b v l c'), head_hidden_states], dim=1),
# 'b v l c -> (b v) l c')
return hidden_states
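
# Minimal sketch of the face-injection step in `process` above, on dummy
# tensors (the sizes are assumptions for illustration). The last view of each
# domain is treated as a face crop; it is downsampled to an (h//3, h//3) patch
# and averaged 50/50 into the head region of the frontal view (view index 0):
#
#     h = 48; hs = h // 3; he = h - hs
#     feats = torch.randn(1, 7, h, h, 320)                  # (b, body views + face view, h, w, c)
#     face = rearrange(feats[:, -1], "b h w c -> b c h w")
#     face = F.interpolate(face, size=(hs, hs), mode="bilinear")
#     face = rearrange(face, "b c h w -> b h w c")
#     feats = feats.clone()
#     feats[:, 0, :hs, hs:he, :] = 0.5 * feats[:, 0, :hs, hs:he, :] + 0.5 * face
#
# The same blend is applied independently to the normal and the color halves.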
class IPCrossAttn(Attention):
    r"""
    Cross-attention layer with (optional) IP-Adapter-style image-prompt
    conditioning.

    Args:
        query_dim (`int`):
            The number of channels in the query (hidden states).
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        ip_scale (`float`, defaults to 1.0):
            The weight scale of the image-prompt features.
    """
def __init__(self,
query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, ip_scale=1.0):
super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention)
self.ip_scale = ip_scale
# self.num_tokens = num_tokens
# self.to_k_ip = nn.Linear(query_dim, self.inner_dim, bias=False)
# self.to_v_ip = nn.Linear(query_dim, self.inner_dim, bias=False)
# self.to_out_ip = nn.ModuleList([])
# self.to_out_ip.append(nn.Linear(self.inner_dim, self.inner_dim, bias=bias))
# self.to_out_ip.append(nn.Dropout(dropout))
# nn.init.zeros_(self.to_k_ip.weight.data)
# nn.init.zeros_(self.to_v_ip.weight.data)
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
):
processor = XFormersIPCrossAttnProcessor()
self.set_processor(processor)
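
# Usage sketch (argument values are illustrative assumptions). With the
# IP-adapter key/value branch commented out above, IPCrossAttn currently acts
# as plain cross-attention against the conditioning embeddings:
#
#     attn = IPCrossAttn(query_dim=320, cross_attention_dim=1024, heads=8,
#                        dim_head=40, dropout=0.0, bias=False,
#                        upcast_attention=False, ip_scale=1.0)
#     attn.set_use_memory_efficient_attention_xformers(True)  # -> XFormersIPCrossAttnProcessor
#     out = attn(hidden_states, encoder_hidden_states=cond_embeds)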
class XFormersIPCrossAttnProcessor:
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
num_views=1
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
hidden_states = attn.batch_to_head_dim(hidden_states)
# ip attn
# hidden_states = rearrange(hidden_states, '(b v) l c -> b v l c', v=num_views)
# body_hidden_states, face_hidden_states = rearrange(hidden_states[:, :-1, :, :], 'b v l c -> (b v) l c'), hidden_states[:, -1, :, :]
# print(body_hidden_states.shape, face_hidden_states.shape)
# import pdb;pdb.set_trace()
# hidden_states = body_hidden_states + attn.ip_scale * repeat(head_hidden_states.detach(), 'b l c -> (b v) l c', v=n_view)
# hidden_states = rearrange(
# torch.cat([rearrange(hidden_states, '(b v) l c -> b v l c'), head_hidden_states.unsqueeze(1)], dim=1),
# 'b v l c -> (b v) l c')
# face cross attention
# ip_hidden_states = repeat(face_hidden_states.detach(), 'b l c -> (b v) l c', v=num_views-1)
# ip_key = attn.to_k_ip(ip_hidden_states)
# ip_value = attn.to_v_ip(ip_hidden_states)
# ip_key = attn.head_to_batch_dim(ip_key).contiguous()
# ip_value = attn.head_to_batch_dim(ip_value).contiguous()
# ip_query = attn.head_to_batch_dim(body_hidden_states).contiguous()
# ip_hidden_states = xformers.ops.memory_efficient_attention(ip_query, ip_key, ip_value, attn_bias=attention_mask)
# ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
# ip_hidden_states = attn.to_out_ip[0](ip_hidden_states)
# ip_hidden_states = attn.to_out_ip[1](ip_hidden_states)
# import pdb;pdb.set_trace()
# body_hidden_states = body_hidden_states + attn.ip_scale * ip_hidden_states
# hidden_states = rearrange(
# torch.cat([rearrange(body_hidden_states, '(b v) l c -> b v l c', v=num_views-1), face_hidden_states.unsqueeze(1)], dim=1),
# 'b v l c -> (b v) l c')
# import pdb;pdb.set_trace()
#
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
# TODO: region control
# region control
# if len(region_control.prompt_image_conditioning) == 1:
# region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
# if region_mask is not None:
# h, w = region_mask.shape[:2]
# ratio = (h * w / query.shape[1]) ** 0.5
# mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
# else:
# mask = torch.ones_like(ip_hidden_states)
# ip_hidden_states = ip_hidden_states * mask
return hidden_states
class RowwiseMVProcessor:
    r"""
    Vanilla-attention counterpart of `XFormersMVAttnProcessor`: performs the
    same row-wise multi-view attention without xformers.
    """
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
num_views=1,
cd_attention_mid=False
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
height = int(math.sqrt(sequence_length))
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
# print('query', query.shape, 'key', key.shape, 'value', value.shape)
#([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
# pdb.set_trace()
# multi-view self-attention
def transpose(tensor):
tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # b v h w c
tensor = torch.cat([tensor_0, tensor_1], dim=3) # b v h 2w c
tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height)
return tensor
if cd_attention_mid:
key = transpose(key)
value = transpose(value)
query = transpose(query)
else:
key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if cd_attention_mid:
hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c
hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c
hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height)
else:
hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
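
# The processor above is the vanilla-attention twin of XFormersMVAttnProcessor:
# the pair of calls
#
#     attention_probs = attn.get_attention_scores(query, key, attention_mask)
#     hidden_states = torch.bmm(attention_probs, value)
#
# computes the same result (up to numerics) as
#
#     xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
#
# without the memory-efficient kernel, so it can serve as a fallback when
# xformers is not installed.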
class CDAttention(Attention):
# def __init__(self, ip_scale,
# query_dim, heads, dim_head, dropout, bias, cross_attention_dim, upcast_attention, processor):
# super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, processor=processor)
# self.ip_scale = ip_scale
# self.to_k_ip = nn.Linear(query_dim, self.inner_dim, bias=False)
# self.to_v_ip = nn.Linear(query_dim, self.inner_dim, bias=False)
# nn.init.zeros_(self.to_k_ip.weight.data)
# nn.init.zeros_(self.to_v_ip.weight.data)
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
):
processor = XFormersCDAttnProcessor()
self.set_processor(processor)
# print("using xformers attention processor")
class XFormersCDAttnProcessor:
    r"""
    Cross-domain (normal + color) attention processor backed by xformers
    memory-efficient attention: the two domains are concatenated along the
    token axis so they attend to each other.
    """
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
num_tasks=2
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
assert num_tasks == 2 # only support two tasks now
        def transpose(tensor):
            # Concatenate the normal and color halves of the batch along the
            # token axis so that attention runs jointly over both domains.
            tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2)  # bv hw c
            tensor = torch.cat([tensor_0, tensor_1], dim=1)  # bv 2hw c
            return tensor
key = transpose(key)
value = transpose(value)
query = transpose(query)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
        # Split the two domains back apart along the token axis and restack them
        # on the batch axis, restoring the (2*b*v, h*w, c) input layout.
        hidden_states_normal, hidden_states_color = torch.chunk(hidden_states, dim=1, chunks=2)  # bv hw c each
        hidden_states = torch.cat([hidden_states_normal, hidden_states_color], dim=0)  # 2bv hw c
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
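
# Shape sketch for the cross-domain transpose used above (batch size and token
# count are illustrative assumptions). The normal and color branches arrive
# stacked on the batch axis; transpose moves the second half onto the token
# axis so every token attends across both domains:
#
#     x = torch.randn(2 * 12, 1024, 320)         # (2*b*v, h*w, c): normal + color
#     x0, x1 = torch.chunk(x, chunks=2, dim=0)   # normal half, color half
#     x = torch.cat([x0, x1], dim=1)             # (b*v, 2*h*w, c)
#
# After attention, chunking on dim=1 and concatenating on dim=0 restores the
# original (2*b*v, h*w, c) layout, as done at the end of __call__.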