# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import pytorch_lightning as pl
import torch
from pl_bolts.datamodules import CIFAR10DataModule
from torch import nn
from torchmetrics import Accuracy

from examples.cifar_ViT import Classifier, VisionTransformer
from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)

# This is very close to the cifar_ViT example and reuses most of its training code; only the model part differs.
# There are many ways to build a MetaFormer with xformers, for instance by
# picking the relevant parts from `xformers.components` and implementing the model explicitly,
# or by patching another existing ViT-like implementation.

# This example takes a different approach: the whole model configuration is defined in one go (a dict structure),
# and the xformers factory then generates the model from it. This hides a lot of the model building
# (although you can inspect the resulting implementation), but it makes hyperparameter search trivial
# (see the sweep sketch after the model definition below).


class MetaVisionTransformer(VisionTransformer):
    def __init__(
        self,
        steps,
        learning_rate=5e-3,
        betas=(0.9, 0.99),
        weight_decay=0.03,
        image_size=32,
        num_classes=10,
        dim=384,
        attention="scaled_dot_product",
        feedforward="MLP",
        residual_norm_style="pre",
        use_rotary_embeddings=True,
        linear_warmup_ratio=0.1,
        classifier=Classifier.GAP,
    ):

        # Deliberately skip VisionTransformer.__init__ and call its parent (the LightningModule
        # base) directly: the ViT-specific model setup is replaced by the metaformer trunk below
        super(VisionTransformer, self).__init__()

        # all the inputs are saved under self.hparams (hyperparams)
        self.save_hyperparameters()

        # Generate the skeleton of our hierarchical Transformer
        # - This is a small poolformer configuration, adapted to the small CIFAR10 pictures (32x32)
        # - Note that this does not match the L1 configuration from the paper, which would correspond to
        #   repeated layers; CIFAR pictures are too small for that configuration to be directly meaningful
        #   (although it would run)
        # - Any related configuration would work, and the attention mechanisms do not have to be the same
        #   across layers (see the note after the list below)
        base_hierarchical_configs = [
            BasicLayerConfig(
                embedding=64,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 4,
                feedforward=feedforward,
                repeat_layer=1,
            ),
            BasicLayerConfig(
                embedding=128,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 16,
                feedforward=feedforward,
                repeat_layer=1,
            ),
            BasicLayerConfig(
                embedding=320,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 64,
                feedforward=feedforward,
                repeat_layer=1,
            ),
            BasicLayerConfig(
                embedding=512,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 256,
                feedforward=feedforward,
                repeat_layer=1,
            ),
        ]
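
        # Note: as mentioned above, the attention mechanism does not have to be shared across
        # stages. A hypothetical variant (hedged sketch, not used here, assuming "pooling" is a
        # registered xformers attention) could build the stages in a loop instead:
        #
        #   stage_embeddings = [64, 128, 320, 512]
        #   stage_attentions = ["pooling", "pooling", "scaled_dot_product", "scaled_dot_product"]
        #   base_hierarchical_configs = [
        #       BasicLayerConfig(
        #           embedding=emb,
        #           attention_mechanism=attn,
        #           patch_size=3,
        #           stride=2,
        #           padding=1,
        #           seq_len=image_size * image_size // 4 ** (i + 1),
        #           feedforward=feedforward,
        #           repeat_layer=1,
        #       )
        #       for i, (emb, attn) in enumerate(zip(stage_embeddings, stage_attentions))
        #   ]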

        # Fill in the gaps in the config
        xformer_config = get_hierarchical_configuration(
            base_hierarchical_configs,
            residual_norm_style=residual_norm_style,
            use_rotary_embeddings=use_rotary_embeddings,
            mlp_multiplier=4,
            dim_head=32,
        )

        # Now instantiate the metaformer trunk
        config = xFormerConfig(xformer_config)
        config.weight_init = "moco"

        print(config)
        self.trunk = xFormer.from_config(config)
        print(self.trunk)

        # The classifier head
        dim = base_hierarchical_configs[-1].embedding
        self.ln = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)
        self.criterion = torch.nn.CrossEntropyLoss()
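        # Note: recent torchmetrics releases require the task to be spelled out explicitly,
        # e.g. Accuracy(task="multiclass", num_classes=num_classes), should the line below fail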
        self.val_accuracy = Accuracy()

    def forward(self, x):
        x = self.trunk(x)
        x = self.ln(x)

        x = x.mean(dim=1)  # global average pooling over the sequence dimension
        x = self.head(x)
        return x
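

# As noted in the introductory comment, this configuration-driven setup makes small hyperparameter
# sweeps trivial. The helper below is only a hedged sketch: it is never called in this example, and
# "pooling" is assumed to be a registered xformers attention variant. It shows that swapping the
# token mixer amounts to changing a single string argument.
def _attention_sweep_sketch(steps, image_size=32, num_classes=10):
    return [
        MetaVisionTransformer(
            steps=steps,
            image_size=image_size,
            num_classes=num_classes,
            attention=attention,
        )
        for attention in ("scaled_dot_product", "pooling")
    ]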


if __name__ == "__main__":
    pl.seed_everything(42)

    # Adjust the batch size depending on the memory available on your machine.
    # You can also use reversible layers to save memory (see the note below)
    REF_BATCH = 768
    BATCH = 256  # lower if not enough GPU memory
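
    # Hedged sketch for the reversible layers mentioned above (not enabled in this example, and
    # assuming the per-stage dicts returned by get_hierarchical_configuration accept the factory's
    # "reversible" flag): set it on every stage before building the trunk, inside
    # MetaVisionTransformer.__init__
    #
    #   for stage in xformer_config:
    #       stage["reversible"] = True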

    MAX_EPOCHS = 50
    NUM_WORKERS = 4
    GPUS = 1

    torch.cuda.manual_seed_all(42)
    torch.manual_seed(42)

    # We'll use a datamodule here, which already handles dataset/dataloader/sampler
    # - See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
    # for a full tutorial
    # - Please note that default transforms are being used
    dm = CIFAR10DataModule(
        data_dir="data",
        batch_size=BATCH,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

    image_size = dm.size(-1)  # 32 for CIFAR
    num_classes = dm.num_classes  # 10 for CIFAR
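    # Note: dm.size() above follows the Lightning API this example was written against; on newer
    # pl_bolts / Lightning versions the shape is exposed as dm.dims, so image_size = dm.dims[-1]
    # would be the equivalent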

    # Compute the total number of optimizer steps: gradients are accumulated until the effective
    # batch size reaches REF_BATCH, so each epoch contributes num_samples // REF_BATCH steps
    batch_size = BATCH * GPUS
    steps = dm.num_samples // REF_BATCH * MAX_EPOCHS
    lm = MetaVisionTransformer(
        steps=steps,
        image_size=image_size,
        num_classes=num_classes,
        attention="scaled_dot_product",
        residual_norm_style="pre",
        feedforward="MLP",
        use_rotary_embeddings=True,
    )
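    # Note: the gpus= argument below matches the Lightning API this example was written against;
    # on Lightning >= 2.0 the equivalent is pl.Trainer(accelerator="gpu", devices=GPUS)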
    trainer = pl.Trainer(
        gpus=GPUS,
        max_epochs=MAX_EPOCHS,
        precision=16,
        accumulate_grad_batches=REF_BATCH // BATCH,
    )
    trainer.fit(lm, dm)

    # Evaluate the trained model on the test set
    trainer.test(lm, datamodule=dm)