# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from xformers.components.attention import Attention, AttentionConfig, register_attention
from xformers.components.attention.core import scaled_dot_product_attention


@dataclass
class LinformerSelfAttentionConfig(AttentionConfig):
    seq_len: int  # dimension of the input sequence
    k: Optional[int]  # dimension of the internal space


@register_attention("linformer", LinformerSelfAttentionConfig)
class LinformerAttention(Attention):
    def __init__(
        self, dropout: float, seq_len: int, k: Optional[int] = None, *args, **kwargs
    ):
""" | |
Linformer attention mechanism, | |
from `Linformer: Self-Attention with Linear Complexity`_, Wang et al (2020). | |
The original notation is kept as is. | |
.. _`Linformer: Self-Attention with Linear Complexity` : https://arxiv.org/abs/2006.04768v2 | |
""" | |
        super().__init__()

        if k is None:
            k = seq_len // 4

        self.k = k
        self.E = nn.Linear(seq_len, k, bias=False)
        self.F = nn.Linear(seq_len, k, bias=False)
        self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.seq_len = seq_len

        # MHA related flags:
        # k and q need to have the same dimension
        self.requires_same_k_q_dimensions = True

        # This attention does not support attention masks
        self.supports_attention_mask = False
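
    # Shape walkthrough, derived from the forward pass below (B: batch,
    # N: sequence length == seq_len, D: head dimension):
    #   q, k, v:        (B, N, D)
    #   E / F project:  (B, D, N) -> (B, D, k) -> transpose -> (B, k, D)
    #   attention:      q @ k_projected^T is (B, N, k) instead of (B, N, N)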
    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
    ):
        # Handle a smaller dimension than expected: pad the sequence axis up to
        # self.seq_len so that the fixed (seq_len -> k) projections line up
        padding = 0
        if q.shape[1] < self.seq_len:
            padding = self.seq_len - q.shape[1]
            pad_dims = (0, 0, 0, padding)

            q = torch.nn.functional.pad(q, pad_dims)
            k = torch.nn.functional.pad(k, pad_dims)
            v = torch.nn.functional.pad(v, pad_dims)

        # Project keys and values along the sequence axis: (B, N, D) -> (B, k, D)
        k_projected = self.E(k.transpose(-2, -1)).transpose(-2, -1)
        v_projected = self.F(v.transpose(-2, -1)).transpose(-2, -1)

        y = scaled_dot_product_attention(
            q=q, k=k_projected, v=v_projected, att_mask=None, dropout=self.attn_drop
        )

        y = self.attn_drop(y)

        # Strip any padding added above, back to the original sequence length
        return y[:, :-padding, :] if padding > 0 else y
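

# A minimal usage sketch, not part of the original module; the shapes are
# illustrative and assume xformers is installed so the imports above resolve.
if __name__ == "__main__":
    attention = LinformerAttention(dropout=0.1, seq_len=512, k=128)

    # Batch of 2 sequences of 384 tokens each (shorter than seq_len, so the
    # forward pass pads internally), model dimension 64
    q = torch.randn(2, 384, 64)
    k = torch.randn(2, 384, 64)
    v = torch.randn(2, 384, 64)

    y = attention(q, k, v)
    assert y.shape == (2, 384, 64)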