# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytorch_lightning as pl
import torch
from pl_bolts.datamodules import CIFAR10DataModule
from torch import nn
from torchmetrics import Accuracy

from examples.cifar_ViT import Classifier, VisionTransformer
from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)

# This is very close to the cifar_ViT example and reuses a lot of its training code;
# only the model part is different.
# There are many ways one can use xformers to write down a MetaFormer, for instance by
# picking up the parts from `xformers.components` and implementing the model explicitly,
# or by patching another existing ViT-like implementation.
# This example takes another approach: we define the whole model configuration in one go
# (a dict structure) and then use the xformers factory to generate the model. This hides
# a lot of the model building (though you can inspect the resulting implementation),
# but makes it trivial to do some hyperparameter search.
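#
# The rough flow used below (just a summary of the code in this file):
#   1. describe each stage of the hierarchy with a BasicLayerConfig
#   2. expand those stage descriptions into a full, layer-by-layer configuration
#      with get_hierarchical_configuration
#   3. wrap that configuration in xFormerConfig and build the trunk with xFormer.from_config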


class MetaVisionTransformer(VisionTransformer):
    def __init__(
        self,
        steps,
        learning_rate=5e-3,
        betas=(0.9, 0.99),
        weight_decay=0.03,
        image_size=32,
        num_classes=10,
        dim=384,
        attention="scaled_dot_product",
        feedforward="MLP",
        residual_norm_style="pre",
        use_rotary_embeddings=True,
        linear_warmup_ratio=0.1,
        classifier=Classifier.GAP,
    ):
        super(VisionTransformer, self).__init__()
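        # Note: super(VisionTransformer, self) skips VisionTransformer.__init__ and
        # dispatches straight to the pl.LightningModule constructor; the parent's model
        # setup is replaced by the trunk built below.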

        # all the inputs are saved under self.hparams (hyperparams)
        self.save_hyperparameters()

        # Generate the skeleton of our hierarchical Transformer
        # - This is a small poolformer configuration, adapted to the small CIFAR10 pictures (32x32)
        # - Please note that this does not match the L1 configuration in the paper, as this would
        #   correspond to repeated layers. CIFAR pictures are too small for this config to be
        #   directly meaningful (although that would run)
        # - Any other related config would work, and the attention mechanisms don't have to be
        #   the same across layers
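        # For reference, with image_size=32 and a stride-2 patch embedding at every stage,
        # the spatial resolution is halved each time, so the sequence lengths below work out
        # to 16*16=256, 8*8=64, 4*4=16 and 2*2=4 tokens respectively.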
        base_hierarchical_configs = [
            BasicLayerConfig(
                embedding=64,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 4,
                feedforward=feedforward,
                repeat_layer=1,
            ),
            BasicLayerConfig(
                embedding=128,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 16,
                feedforward=feedforward,
                repeat_layer=1,
            ),
            BasicLayerConfig(
                embedding=320,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 64,
                feedforward=feedforward,
                repeat_layer=1,
            ),
            BasicLayerConfig(
                embedding=512,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 256,
                feedforward=feedforward,
                repeat_layer=1,
            ),
        ]

        # Fill in the gaps in the config
        xformer_config = get_hierarchical_configuration(
            base_hierarchical_configs,
            residual_norm_style=residual_norm_style,
            use_rotary_embeddings=use_rotary_embeddings,
            mlp_multiplier=4,
            dim_head=32,
        )
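        # (Roughly speaking, the helper above expands each stage description into the full
        # per-layer dicts the factory expects; mlp_multiplier sizes the feedforward hidden
        # dimension relative to the embedding, and dim_head sets the per-head width,
        # presumably giving embedding // dim_head attention heads per stage. Treat this as
        # a sketch of the intent rather than an exact description of the helper's internals.)
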
        # Now instantiate the metaformer trunk
        config = xFormerConfig(xformer_config)
        config.weight_init = "moco"

        print(config)
        self.trunk = xFormer.from_config(config)
        print(self.trunk)

        # The classifier head
        dim = base_hierarchical_configs[-1].embedding
        self.ln = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.val_accuracy = Accuracy()
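        # Note: recent torchmetrics releases require the task to be spelled out, e.g.
        # Accuracy(task="multiclass", num_classes=num_classes); the bare call above matches
        # the torchmetrics version this example was written against.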

    def forward(self, x):
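        # Assumed shapes for the CIFAR10 defaults above (a sketch, not checked at runtime):
        # x: (B, 3, 32, 32) -> trunk: (B, 4, 512) -> mean over tokens: (B, 512) -> head: (B, 10)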
        x = self.trunk(x)
        x = self.ln(x)

        x = x.mean(dim=1)  # mean over sequence len
        x = self.head(x)
        return x


if __name__ == "__main__":
    pl.seed_everything(42)

    # Adjust batch depending on the available memory on your machine.
    # You can also use reversible layers to save memory
    REF_BATCH = 768
    BATCH = 256  # lower if not enough GPU memory
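    # With the Trainer below accumulating REF_BATCH // BATCH = 3 batches, the effective
    # batch size stays at 3 * 256 = 768 samples per optimizer step, even though only 256
    # images are held in memory at once.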
    MAX_EPOCHS = 50
    NUM_WORKERS = 4
    GPUS = 1

    torch.cuda.manual_seed_all(42)
    torch.manual_seed(42)

    # We'll use a datamodule here, which already handles dataset/dataloader/sampler
    # - See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
    #   for a full tutorial
    # - Please note that default transforms are being used
    dm = CIFAR10DataModule(
        data_dir="data",
        batch_size=BATCH,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

    image_size = dm.size(-1)  # 32 for CIFAR
    num_classes = dm.num_classes  # 10 for CIFAR

    # compute total number of steps
    batch_size = BATCH * GPUS
    steps = dm.num_samples // REF_BATCH * MAX_EPOCHS
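    # steps counts optimizer updates over the whole run (one update per effective batch of
    # REF_BATCH samples, for MAX_EPOCHS epochs); the model presumably uses it, together with
    # linear_warmup_ratio, to size its learning-rate warmup schedule.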
    lm = MetaVisionTransformer(
        steps=steps,
        image_size=image_size,
        num_classes=num_classes,
        attention="scaled_dot_product",
        residual_norm_style="pre",
        feedforward="MLP",
        use_rotary_embeddings=True,
    )
    trainer = pl.Trainer(
        gpus=GPUS,
        max_epochs=MAX_EPOCHS,
        precision=16,
        accumulate_grad_batches=REF_BATCH // BATCH,
    )
    trainer.fit(lm, dm)

    # check the training
    trainer.test(lm, datamodule=dm)