# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# CREDITS:
# inspired by
# https://github.com/nateraw/lightning-vision-transformer
# which in turn references https://github.com/lucidrains/vit-pytorch
# Original author: Sean Naren

import math
from enum import Enum

import pytorch_lightning as pl
import torch
from pl_bolts.datamodules import CIFAR10DataModule
from torch import nn
from torchmetrics import Accuracy

from xformers.factory import xFormer, xFormerConfig
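# Package assumptions (versions not pinned here): xformers, pytorch-lightning,
# lightning-bolts (for pl_bolts.datamodules) and torchmetrics must be installed,
# e.g. `pip install xformers pytorch-lightning lightning-bolts torchmetrics`.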

class Classifier(str, Enum):
    GAP = "gap"  # global average pooling over the patch sequence
    TOKEN = "token"  # classify from a learnable class token

class VisionTransformer(pl.LightningModule):
    def __init__(
        self,
        steps,
        learning_rate=5e-4,
        betas=(0.9, 0.99),
        weight_decay=0.03,
        image_size=32,
        num_classes=10,
        patch_size=2,
        dim=384,
        n_layer=6,
        n_head=6,
        resid_pdrop=0.0,
        attn_pdrop=0.0,
        mlp_pdrop=0.0,
        attention="scaled_dot_product",
        residual_norm_style="pre",
        hidden_layer_multiplier=4,
        use_rotary_embeddings=True,
        linear_warmup_ratio=0.1,
        classifier: Classifier = Classifier.TOKEN,
    ):
        super().__init__()

        # all the inputs are saved under self.hparams (hyperparams)
        self.save_hyperparameters()

        assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
        num_patches = (image_size // patch_size) ** 2
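        # e.g. with the CIFAR-10 defaults (image_size=32, patch_size=2),
        # this gives (32 // 2) ** 2 = 256 patches per image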
        # A list of the encoder or decoder blocks which constitute the Transformer.
        xformer_config = [
            {
                "block_type": "encoder",
                "num_layers": n_layer,
                "dim_model": dim,
                "residual_norm_style": residual_norm_style,
                "multi_head_config": {
                    "num_heads": n_head,
                    "residual_dropout": resid_pdrop,
                    "use_rotary_embeddings": use_rotary_embeddings,
                    "attention": {
                        "name": attention,
                        "dropout": attn_pdrop,
                        "causal": False,
                    },
                },
                "feedforward_config": {
                    "name": "MLP",
                    "dropout": mlp_pdrop,
                    "activation": "gelu",
                    "hidden_layer_multiplier": hidden_layer_multiplier,
                },
                "position_encoding_config": {
                    "name": "learnable",
                    "seq_len": num_patches,
                    "dim_model": dim,
                    "add_class_token": classifier == Classifier.TOKEN,
                },
                "patch_embedding_config": {
                    "in_channels": 3,
                    "out_channels": dim,
                    "kernel_size": patch_size,
                    "stride": patch_size,
                },
            }
        ]
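        # In short: a convolutional patch embedding, learnable positional
        # encodings (plus an optional class token), then n_layer encoder
        # blocks with the chosen attention, MLP and residual norm style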

        # The ViT trunk
        config = xFormerConfig(xformer_config)
        self.vit = xFormer.from_config(config)
        print(self.vit)

        # The classifier head
        self.ln = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)
        self.criterion = torch.nn.CrossEntropyLoss()

        # NOTE: with torchmetrics >= 0.11 this needs an explicit task,
        # e.g. Accuracy(task="multiclass", num_classes=num_classes)
        self.val_accuracy = Accuracy()

    @staticmethod
    def linear_warmup_cosine_decay(warmup_steps, total_steps):
        """
        Linear warmup for warmup_steps, with cosine annealing to 0 at total_steps
        """

        def fn(step):
            if step < warmup_steps:
                # linear ramp from 0 to 1 over the warmup
                return float(step) / float(max(1, warmup_steps))

            # cosine decay from 1 down to 0 over the remaining steps
            progress = float(step - warmup_steps) / float(
                max(1, total_steps - warmup_steps)
            )
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        return fn
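
    # Worked example with hypothetical numbers: warmup_steps=100 and
    # total_steps=1000 give a factor of 0.5 at step 50, 1.0 at step 100,
    # 0.5 again at step 550 (progress = 450 / 900, cos(pi / 2) = 0),
    # and 0.0 at step 1000.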

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            betas=self.hparams.betas,
            weight_decay=self.hparams.weight_decay,
        )
        warmup_steps = int(self.hparams.linear_warmup_ratio * self.hparams.steps)
        scheduler = {
            "scheduler": torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                self.linear_warmup_cosine_decay(warmup_steps, self.hparams.steps),
            ),
            # step the schedule after every optimizer step, not every epoch
            "interval": "step",
        }
        return [optimizer], [scheduler]

    def forward(self, x):
        x = self.vit(x)
        x = self.ln(x)

        if self.hparams.classifier == Classifier.TOKEN:
            x = x[:, 0]  # classify from the class token only
        elif self.hparams.classifier == Classifier.GAP:
            x = x.mean(dim=1)  # mean over the sequence length

        x = self.head(x)
        return x
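
    # Shape sketch, assuming the defaults above: a (B, 3, 32, 32) batch is
    # patchified by the trunk to (B, 257, 384) (256 patches + 1 class token),
    # reduced to (B, 384) by the classifier choice, then projected to (B, 10).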

    def training_step(self, batch, _):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)

        self.logger.log_metrics(
            {
                "train_loss": loss.mean(),
                "learning_rate": self.lr_schedulers().get_last_lr()[0],
            },
            step=self.global_step,
        )
        return loss

    def evaluate(self, batch, stage=None):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        acc = self.val_accuracy(y_hat, y)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, _):
        self.evaluate(batch, "val")

    def test_step(self, batch, _):
        self.evaluate(batch, "test")


if __name__ == "__main__":
    pl.seed_everything(42)

    # Adjust the batch size depending on the memory available on your machine.
    # You can also use reversible layers to save memory.
    REF_BATCH = 512
    BATCH = 128
    MAX_EPOCHS = 30
    NUM_WORKERS = 4
    GPUS = 1

    # We'll use a datamodule here, which already handles the dataset, dataloaders
    # and sampler.
    # - See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
    #   for a full tutorial
    # - Please note that the default transforms are being used
    dm = CIFAR10DataModule(
        data_dir="data",
        batch_size=BATCH,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    image_size = dm.size(-1)  # 32 for CIFAR
    num_classes = dm.num_classes  # 10 for CIFAR

    # Compute the total number of optimizer steps: gradient accumulation below
    # makes the effective batch size REF_BATCH, hence
    # num_samples // REF_BATCH steps per epoch.
    steps = dm.num_samples // REF_BATCH * MAX_EPOCHS
    lm = VisionTransformer(
        steps=steps,
        image_size=image_size,
        num_classes=num_classes,
        attention="scaled_dot_product",
        classifier=Classifier.TOKEN,
        residual_norm_style="pre",
        use_rotary_embeddings=True,
    )

    trainer = pl.Trainer(
        gpus=GPUS,  # on PyTorch Lightning >= 2.0, use accelerator="gpu", devices=GPUS
        max_epochs=MAX_EPOCHS,
        detect_anomaly=False,
        precision=16,
        accumulate_grad_batches=REF_BATCH // BATCH,
    )
    trainer.fit(lm, dm)

    # sanity-check the training by evaluating on the held-out test set
    trainer.test(lm, datamodule=dm)
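
    # A minimal post-training inference sketch (illustrative, not part of the
    # original example); assumes Lightning has moved the model back to CPU
    # after trainer.test, alongside the dataloader tensors.
    lm.eval()
    with torch.no_grad():
        x, y = next(iter(dm.test_dataloader()))
        preds = lm(x).argmax(dim=-1)
        print(f"predictions: {preds[:8].tolist()} / targets: {y[:8].tolist()}")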