Replace all the attention layers in an existing ViT model with a sparse equivalent?
===================================================================================

Let's say you're used to working with a given Transformer-based model and want to experiment with one of the attention mechanisms supported by xFormers.

The following example shows how to do this for a reference ViT from pytorch-image-models_, but most of the steps translate just as well
to other model sources. In any case, please check the notebooks in the repository_ for a more exhaustive take.


.. code-block:: python

    import timm
    import torch
    from timm.models.vision_transformer import VisionTransformer
    from torch import nn
    from xformers.helpers.timm_sparse_attention import TimmSparseAttention

    img_size = 224
    patch_size = 16

    # Get a reference ViT model
    model = VisionTransformer(img_size=img_size, patch_size=patch_size,
                              embed_dim=96, depth=8, num_heads=8, mlp_ratio=3.,
                              qkv_bias=False, norm_layer=nn.LayerNorm).cuda()


    # Define the mask that we want to use
    # We suppose in this snippet that you have a precise mask in mind already,
    # but several helpers and examples are provided in `xformers.components.attention.attention_patterns`
    my_fancy_mask: torch.Tensor  # This would be for you to define
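
    # For illustration, here is one way such a mask could be built with the helpers
    # mentioned above (a sketch: `local_2d_pattern` and the distance below are
    # arbitrary choices). With img_size=224 and patch_size=16 the sequence is
    # 14 * 14 = 196 patch tokens plus one class token, hence the extra row and column.
    from xformers.components.attention.attention_patterns import local_2d_pattern

    n_side = img_size // patch_size  # 14 patches per side
    seq_len = n_side * n_side + 1  # 196 patch tokens + 1 class token
    patch_mask = local_2d_pattern(n_side, n_side, distance=3)  # (196, 196) boolean pattern
    my_fancy_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
    my_fancy_mask[1:, 1:] = patch_mask  # the class token row/column stay fully visible
    my_fancy_mask = my_fancy_mask.cuda()  # assuming the mask should live on the model's device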

    # Define a recursive monkey patching function
    def replace_attn_with_xformers_one(module, att_mask):
        module_output = module
        if isinstance(module, timm.models.vision_transformer.Attention):
            qkv = module.qkv
            # in_features of the qkv projection is the embedding dimension
            dim = qkv.weight.shape[1]
            # Extra parameters can be exposed in TimmSparseAttention, this is a minimal example
            module_output = TimmSparseAttention(dim, module.num_heads, attn_mask=att_mask)
        for name, child in module.named_children():
            module_output.add_module(name, replace_attn_with_xformers_one(child, att_mask))
        # Drop the reference to the original module so that it can be garbage collected
        del module

        return module_output

    # Now we can just patch our reference model and get a sparse-aware variant.
    # The new attention modules are created on CPU, so move the model back to the GPU.
    model = replace_attn_with_xformers_one(model, my_fancy_mask).cuda()
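
    # Quick sanity check (a sketch): the patched model consumes the same inputs as the
    # original one; with the default classification head this prints torch.Size([2, 1000])
    x = torch.randn(2, 3, img_size, img_size, device="cuda")
    print(model(x).shape)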

Note that in practice, exchanging all the attention layers for a sparse alternative may not be a good idea: the attention layers closest to the output typically do not exhibit a clear sparsity pattern. You can alter ``replace_attn_with_xformers_one`` above, or manually replace only the attention layers that you would like to sparsify.
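
Starting again from the dense reference model, a more surgical variant could for instance
patch only the first blocks and keep the layers closest to the output dense. This is a
minimal sketch: the number of patched blocks is an arbitrary choice, and it assumes the
usual timm ``blocks`` / ``attn`` module layout.

.. code-block:: python

    n_sparse_blocks = 4  # arbitrary choice: only sparsify the first four blocks

    for i, block in enumerate(model.blocks):
        if i < n_sparse_blocks:
            block.attn = TimmSparseAttention(
                model.embed_dim, block.attn.num_heads, attn_mask=my_fancy_mask
            )

    model = model.cuda()  # the freshly created modules start out on CPU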


.. _pytorch-image-models: https://github.com/rwightman/pytorch-image-models
.. _repository: https://github.com/facebookresearch/xformers