import torch
import torch.nn as nn
import torch.cuda.amp as amp
import torch.nn.functional as F
import math
import os
import time
import numpy as np
import random

try:
    from flash_attn.flash_attention import FlashAttention
except ImportError:
    # flash-attn (v1 API) is optional: without it every FlashAttentionBlock uses
    # the plain PyTorch attention path (FlashAttentionBlock._torch_attention).
    FlashAttention = None


class FlashAttentionBlock(nn.Module):
    r"""Residual multi-head attention over the spatial positions of a feature map.

    Input ``x`` is group-normalized and projected to q/k/v with a 1x1 conv.
    Optionally, context tokens are projected to extra keys/values that are
    prepended to the spatial ones. The attention output goes through a 1x1 conv
    (``proj``) and is added back to the input.

    The flash-attn kernel is used when it is installed, the head size is supported
    and the tensors are on CUDA. Otherwise an equivalent PyTorch implementation runs.
    """

    def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None, batch_size=4):
        r"""
        Args:
            dim: input/output channels. It must be divisible by 32 (GroupNorm groups)
                and by the number of heads.
            context_dim: feature size of the optional context tokens. When None,
                no context projection is built and ``forward`` rejects a context.
            num_heads: number of attention heads; used only when ``head_dim`` is falsy.
            head_dim: channels per head; takes precedence over ``num_heads``.
            batch_size: unused; kept for interface compatibility.

        Raises:
            ValueError: if neither ``head_dim`` nor ``num_heads`` is given.
        """
        # consider head_dim first, then num_heads
        if head_dim:
            num_heads = dim // head_dim
        elif num_heads is None:
            raise ValueError('either head_dim or num_heads must be given')
        head_dim = dim // num_heads
        assert num_heads * head_dim == dim
        super(FlashAttentionBlock, self).__init__()
        self.dim = dim
        self.context_dim = context_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Applied to both q and k in the PyTorch path, so q.k is scaled by 1/sqrt(head_dim).
        self.scale = math.pow(head_dim, -0.25)

        # layers
        self.norm = nn.GroupNorm(32, dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        if context_dim is not None:
            self.context_kv = nn.Linear(context_dim, dim * 2)
        self.proj = nn.Conv2d(dim, dim, 1)

        # The flash path is only used for head_dim <= 128 and a multiple of 8 (flash-attn
        # v1 kernel limits). softmax_scale=None means the kernel's default 1/sqrt(head_dim),
        # the same total scale as the PyTorch path.
        use_flash = FlashAttention is not None and head_dim <= 128 and head_dim % 8 == 0
        self.flash_attn = (
            FlashAttention(softmax_scale=None, attention_dropout=0.0) if use_flash else None
        )

        # Zero only the output projection weight; its bias keeps the default init.
        nn.init.zeros_(self.proj.weight)
        # _init_weight is not applied by default; use self.apply(self._init_weight) to opt in.

    def _init_weight(self, module):
        """Re-initialize Linear/Conv2d weights with N(0, 0.15) and zero their biases."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.15)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Conv2d):
            module.weight.data.normal_(mean=0.0, std=0.15)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, x, context=None):
        r"""x: [B, C, H, W]. context: [B, L, context_dim] or None.

        Returns ``x + proj(attn)``, with the same shape as ``x``. Queries are the H*W
        spatial positions. Keys/values are the L context tokens (when given) followed
        by the H*W spatial positions.

        Raises:
            ValueError: if a context is given to a block built with ``context_dim=None``.
        """
        identity = x
        b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value: each [B, n, d, H*W]
        x = self.norm(x)
        q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
        if context is not None:
            if self.context_dim is None:
                raise ValueError('context was given but the block was built with context_dim=None')
            # context keys/values: each [B, n, d, L]
            ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1)
            k = torch.cat([ck, k], dim=-1)
            v = torch.cat([cv, v], dim=-1)

        if self.flash_attn is not None and q.is_cuda:
            out = self._flash_attention(q, k, v)
        else:
            out = self._torch_attention(q, k, v)
        # [B, n, d, H*W] -> [B, C, H, W]; channel index is head * d + dim_in_head
        out = out.reshape(b, c, h, w)

        # output
        x = self.proj(out)
        return x + identity

    def _flash_attention(self, q, k, v):
        """Attention via the flash-attn packed-qkv kernel.

        q: [B, n, d, Lq]; k, v: [B, n, d, Lk] with Lk >= Lq. Returns [B, n, d, Lq].
        """
        b, n, d, lq = q.shape
        # The packed kernel needs equal q/k/v lengths. Pad the queries with zero rows
        # matching the number of context tokens; their outputs are dropped below.
        pad = k.size(-1) - lq
        if pad > 0:
            q = torch.cat([q, q.new_zeros(b, n, d, pad)], dim=-1)
        qkv = torch.cat([q, k, v], dim=1)
        origin_dtype = qkv.dtype
        # [B, 3n, d, S] -> [B, S, 3, n, d]; the kernel is run in fp16
        qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n, d).half().contiguous()
        out, _ = self.flash_attn(qkv)
        # keep only the real query positions; out is [B, S, n, d]
        out = out.to(origin_dtype)[:, :lq]
        return out.permute(0, 2, 3, 1)

    def _torch_attention(self, q, k, v):
        """Plain PyTorch attention, used when the flash path is unavailable.

        q: [B, n, d, Lq]; k, v: [B, n, d, Lk]. Returns [B, n, d, Lq].
        """
        attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)  # [B, n, Lq, Lk]
        # softmax in float32 for stability under autocast, then back to the matmul dtype
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        return torch.matmul(v, attn.transpose(-1, -2))


if __name__ == '__main__':
    batch_size = 8
    flash_net = FlashAttentionBlock(dim=1280, context_dim=512, num_heads=None, head_dim=64,
                                    batch_size=batch_size).cuda()
    x = torch.randn([batch_size, 1280, 32, 32], dtype=torch.float32).cuda()
    context = torch.randn([batch_size, 4, 512], dtype=torch.float32).cuda()
    # context = None
    flash_net.eval()
    # inference benchmark: no autograd graph needed
    with torch.no_grad(), amp.autocast(enabled=True):
        # warm up
        for _ in range(5):
            y = flash_net(x, context)
        torch.cuda.synchronize()
        s1 = time.time()
        for _ in range(10):
            y = flash_net(x, context)
        torch.cuda.synchronize()
        s2 = time.time()
    print(f'Average cost time {(s2-s1)*1000/10} ms')