# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from xformers.components.attention import Attention, AttentionConfig, register_attention


@dataclass
class PoolingAttentionConfig(AttentionConfig):
    pool_size: int  # size of the pooling window
    stride: Optional[int]  # stride of the pooling operator
    padding: Optional[int]  # explicit padding; defaults to pool_size // 2


@register_attention("pooling", PoolingAttentionConfig)
class Pooling(Attention):
    def __init__(
        self,
        pool_size: int = 3,
        stride: int = 1,
        padding: Optional[int] = None,
        *_,
        **__,
    ):
""" | |
Pooling token mixing mechanism, as proposed in | |
`Metaformer is actually what you need for vision`_, Yu et al (2021). | |
The original notation is kept as is. | |
.. _`Metaformer is actually what you need for vision` : https://arxiv.org/pdf/2111.11418v1.pdf | |
""" | |
        super().__init__()

        padding = padding if padding is not None else pool_size // 2
        self.pool = nn.AvgPool2d(
            pool_size,
            stride=stride,
            padding=padding,
            count_include_pad=False,
        )

        # MHA related flags:
        # k and q do not need to share the same dimension here
        self.requires_same_k_q_dimensions = False

        # This attention does not support attention masks
        self.supports_attention_mask = False

        # This "attention" (token mixing) skips the multihead attention altogether
        self.requires_skip_multi_head = True

        # This operator does not really handle q, k, v: no input projection needed
        self.requires_input_projection = False

        # This attention requires the 2d structure out of the context,
        # implicitly assumed to be a squared length
        self.requires_squared_context = True

    def forward(self, q: torch.Tensor, *_, **__):
        # Expose the 2D token structure
        B, HW, C = q.shape
        H = int(math.sqrt(HW))
        assert H * H == HW, "This mixer expects a square token grid"

        q = q.transpose(-2, -1).reshape(B, C, H, H)

        # 2D pool; subtract the input to compensate for the residual path
        # wrapped around this block, so that the net effect is plain pooling
        x_pool = self.pool(q) - q

        # Get back to B HW C
        return x_pool.flatten(2, 3).transpose(-2, -1)
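

# ---------------------------------------------------------------------------
# Usage sketch: a minimal smoke test for the mixer above. This block is an
# illustrative addition, not part of the upstream file, and the tensor sizes
# are arbitrary assumptions. The sequence length must be a perfect square;
# with the default stride of 1 and pool_size // 2 padding, the output keeps
# the input shape.
if __name__ == "__main__":
    mixer = Pooling(pool_size=3, stride=1)

    x = torch.randn(2, 16 * 16, 64)  # B=2, a 16x16 token grid flattened, C=64
    y = mixer(x)

    assert y.shape == x.shape
    print(y.shape)  # torch.Size([2, 256, 64])

    # Registry-based construction should behave the same once this module is
    # imported (assumes xformers' build_attention helper, which instantiates
    # a registered attention from a config dict):
    #
    #   from xformers.components.attention import build_attention
    #   mixer = build_attention({"name": "pooling", "pool_size": 3, "stride": 1})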