# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytorch_lightning as pl
import torch
from pl_bolts.datamodules import CIFAR10DataModule
from torch import nn
from torchmetrics import Accuracy

from examples.cifar_ViT import Classifier, VisionTransformer
from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)
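
# NOTE: this example targets the dependency versions it originally shipped with
# (a hedged assumption): pl_bolts for the CIFAR10 datamodule, and a pre-2.0
# pytorch_lightning API (`gpus=`, `precision=16` in the Trainer further down).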

# This is very close to the cifar_ViT example and reuses a lot of its training code;
# only the model part is different.
# There are many ways to write down a MetaFormer with xformers, for instance by
# picking up the parts from `xformers.components` and implementing the model
# explicitly (a sketch follows below), or by patching another existing ViT-like
# implementation.
# This example takes another approach: we define the whole model configuration in
# one go (a dict structure) and then use the xformers factory to generate the model.
# This hides a lot of the model building (though you can inspect the resulting
# implementation), but makes it trivial to run a hyperparameter search.
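# For reference, the "explicit" route mentioned above could look like the sketch
# below. This is a hedged illustration of the `xformers.components` builders, not
# code from this example; check the signatures against your installed version:
#
#   from xformers.components import MultiHeadDispatch
#   from xformers.components.attention import build_attention
#
#   attention = build_attention({"name": "scaled_dot_product", "dropout": 0.0})
#   multi_head = MultiHeadDispatch(dim_model=384, num_heads=6, attention=attention)
#   # multi_head(query, key, value) -> (batch, seq, dim_model)
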
class MetaVisionTransformer(VisionTransformer):
    def __init__(
        self,
        steps,
        learning_rate=5e-3,
        betas=(0.9, 0.99),
        weight_decay=0.03,
        image_size=32,
        num_classes=10,
        dim=384,
        attention="scaled_dot_product",
        feedforward="MLP",
        residual_norm_style="pre",
        use_rotary_embeddings=True,
        linear_warmup_ratio=0.1,
        classifier=Classifier.GAP,
    ):
        # Skip VisionTransformer.__init__ on purpose: we only want the
        # LightningModule setup and the training logic inherited from the parent,
        # not its model definition
        super(VisionTransformer, self).__init__()

        # All the constructor arguments are saved under self.hparams (hyperparams)
        self.save_hyperparameters()

        # Generate the skeleton of our hierarchical transformer
        # - This is a small PoolFormer-like configuration, adapted to the small
        #   CIFAR10 pictures (32x32). See the note after the list for how the
        #   seq_len values are derived.
        # - Please note that this does not match the L1 configuration in the
        #   MetaFormer paper, since that one uses repeated layers; CIFAR pictures
        #   are too small for that configuration to be directly meaningful
        #   (although it would run)
        # - Any other related config would work, and the attention mechanisms
        #   don't have to be the same across layers
        base_hierarchical_configs = [
            BasicLayerConfig(
                embedding=64,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 4,
                feedforward=feedforward,
                repeat_layer=1,
            ),
            BasicLayerConfig(
                embedding=128,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 16,
                feedforward=feedforward,
                repeat_layer=1,
            ),
            BasicLayerConfig(
                embedding=320,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 64,
                feedforward=feedforward,
                repeat_layer=1,
            ),
            BasicLayerConfig(
                embedding=512,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 256,
                feedforward=feedforward,
                repeat_layer=1,
            ),
        ]
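
        # Why these seq_len values: every stage embeds patches with stride 2, so
        # the spatial resolution is halved per stage and the token count divided
        # by 4. With image_size=32 that gives (32 * 32) // 4 = 256 tokens after
        # stage 1, then 64, 16, and finally 4 tokens for the last stage.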

        # Fill in the gaps in the config (norm style, rotary embeddings,
        # MLP expansion ratio, per-head dimension)
        xformer_config = get_hierarchical_configuration(
            base_hierarchical_configs,
            residual_norm_style=residual_norm_style,
            use_rotary_embeddings=use_rotary_embeddings,
            mlp_multiplier=4,
            dim_head=32,
        )
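        # Assuming dim_head fixes the per-head width (hedged: check the helper's
        # semantics in your version), the head counts per stage would be
        # embedding // dim_head = 2, 4, 10 and 16 heads.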

        # Now instantiate the metaformer trunk
        config = xFormerConfig(xformer_config)
        config.weight_init = "moco"

        print(config)
        self.trunk = xFormer.from_config(config)
        print(self.trunk)

        # The classifier head
        dim = base_hierarchical_configs[-1].embedding
        self.ln = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)
        self.criterion = torch.nn.CrossEntropyLoss()
        # torchmetrics >= 0.11 requires the task to be spelled out explicitly;
        # with older versions, a plain Accuracy() works
        self.val_accuracy = Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        # (B, 3, H, W) images in, (B, seq, dim) tokens out of the trunk
        x = self.trunk(x)
        x = self.ln(x)
        x = x.mean(dim=1)  # mean over sequence len, i.e. global average pooling
        x = self.head(x)
        return x
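
# A quick smoke test of the model above (a hedged sketch, not part of the original
# training script):
#
#   model = MetaVisionTransformer(steps=100)
#   logits = model(torch.randn(2, 3, 32, 32))
#   assert logits.shape == (2, 10)  # with the default 32x32 / 10-class settings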

if __name__ == "__main__":
    pl.seed_everything(42)

    # Adjust the batch size depending on the memory available on your machine.
    # You can also use reversible layers to save memory
    REF_BATCH = 768
    BATCH = 256  # lower if not enough GPU memory
    MAX_EPOCHS = 50
    NUM_WORKERS = 4
    GPUS = 1

    torch.cuda.manual_seed_all(42)
    torch.manual_seed(42)

    # We'll use a datamodule here, which already handles dataset/dataloader/sampler
    # - See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
    #   for a full tutorial
    # - Please note that the default transforms are being used
    dm = CIFAR10DataModule(
        data_dir="data",
        batch_size=BATCH,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

    image_size = dm.size(-1)  # 32 for CIFAR
    num_classes = dm.num_classes  # 10 for CIFAR

    # Compute the total number of optimizer steps; gradient accumulation (below)
    # makes REF_BATCH the effective batch size, hence the division by REF_BATCH
    batch_size = BATCH * GPUS
    steps = dm.num_samples // REF_BATCH * MAX_EPOCHS
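    # For instance, with the default pl_bolts validation split CIFAR10 keeps
    # 45,000 training samples, so 45000 // 768 * 50 = 2900 steps (hedged: the
    # exact count depends on your datamodule's split)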

    lm = MetaVisionTransformer(
        steps=steps,
        image_size=image_size,
        num_classes=num_classes,
        attention="scaled_dot_product",
        residual_norm_style="pre",
        feedforward="MLP",
        use_rotary_embeddings=True,
    )

    trainer = pl.Trainer(
        gpus=GPUS,
        max_epochs=MAX_EPOCHS,
        precision=16,
        # 768 // 256 = 3 batches are accumulated to reach the reference batch size
        accumulate_grad_batches=REF_BATCH // BATCH,
    )
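    # NOTE: `gpus=` and `precision=16` are the pre-2.0 Lightning API. On
    # pytorch_lightning >= 2.0 the equivalents would be (hedged):
    #   pl.Trainer(accelerator="gpu", devices=GPUS, precision="16-mixed", ...)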

    trainer.fit(lm, dm)

    # Evaluate the trained model on the test set
    trainer.test(lm, datamodule=dm)