# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS:
# inspired by
# https://github.com/nateraw/lightning-vision-transformer
# which in turn references https://github.com/lucidrains/vit-pytorch
# Original author: Sean Naren
import math
from enum import Enum
import pytorch_lightning as pl
import torch
from pl_bolts.datamodules import CIFAR10DataModule
from torch import nn
from torchmetrics import Accuracy
from xformers.factory import xFormer, xFormerConfig
class Classifier(str, Enum):
GAP = "gap"
TOKEN = "token"
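# (GAP averages the patch embeddings before the linear head; TOKEN prepends a
# learnable class token to the sequence and classifies from it, see
# VisionTransformer.forward below.)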
class VisionTransformer(pl.LightningModule):
def __init__(
self,
steps,
learning_rate=5e-4,
betas=(0.9, 0.99),
weight_decay=0.03,
image_size=32,
num_classes=10,
patch_size=2,
dim=384,
n_layer=6,
n_head=6,
resid_pdrop=0.0,
attn_pdrop=0.0,
mlp_pdrop=0.0,
attention="scaled_dot_product",
residual_norm_style="pre",
hidden_layer_multiplier=4,
use_rotary_embeddings=True,
linear_warmup_ratio=0.1,
classifier: Classifier = Classifier.TOKEN,
):
super().__init__()
# all the inputs are saved under self.hparams (hyperparams)
self.save_hyperparameters()
        assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
num_patches = (image_size // patch_size) ** 2
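        # e.g. with the defaults image_size=32 and patch_size=2, the trunk sees
        # (32 // 2) ** 2 = 256 patches per image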
# A list of the encoder or decoder blocks which constitute the Transformer.
xformer_config = [
{
"block_type": "encoder",
"num_layers": n_layer,
"dim_model": dim,
"residual_norm_style": residual_norm_style,
"multi_head_config": {
"num_heads": n_head,
"residual_dropout": resid_pdrop,
"use_rotary_embeddings": use_rotary_embeddings,
"attention": {
"name": attention,
"dropout": attn_pdrop,
"causal": False,
},
},
"feedforward_config": {
"name": "MLP",
"dropout": mlp_pdrop,
"activation": "gelu",
"hidden_layer_multiplier": hidden_layer_multiplier,
},
"position_encoding_config": {
"name": "learnable",
"seq_len": num_patches,
"dim_model": dim,
"add_class_token": classifier == Classifier.TOKEN,
},
"patch_embedding_config": {
"in_channels": 3,
"out_channels": dim,
"kernel_size": patch_size,
"stride": patch_size,
},
}
]
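        # Each encoder block above combines multi-head self-attention (with rotary
        # embeddings when enabled) and a GELU MLP whose hidden width is
        # hidden_layer_multiplier * dim, wrapped in residuals with the requested
        # norm style; patch embedding is a strided conv with kernel_size == stride == patch_size.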
# The ViT trunk
config = xFormerConfig(xformer_config)
self.vit = xFormer.from_config(config)
print(self.vit)
# The classifier head
self.ln = nn.LayerNorm(dim)
self.head = nn.Linear(dim, num_classes)
self.criterion = torch.nn.CrossEntropyLoss()
        self.val_accuracy = Accuracy(task="multiclass", num_classes=num_classes)  # torchmetrics >= 0.11 requires an explicit task
@staticmethod
def linear_warmup_cosine_decay(warmup_steps, total_steps):
"""
Linear warmup for warmup_steps, with cosine annealing to 0 at total_steps
"""
def fn(step):
if step < warmup_steps:
return float(step) / float(max(1, warmup_steps))
progress = float(step - warmup_steps) / float(
max(1, total_steps - warmup_steps)
)
return 0.5 * (1.0 + math.cos(math.pi * progress))
return fn
def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.parameters(),
lr=self.hparams.learning_rate,
betas=self.hparams.betas,
weight_decay=self.hparams.weight_decay,
)
warmup_steps = int(self.hparams.linear_warmup_ratio * self.hparams.steps)
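        # e.g. with the default linear_warmup_ratio=0.1, the learning rate ramps up
        # linearly over the first 10% of all optimizer steps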
scheduler = {
"scheduler": torch.optim.lr_scheduler.LambdaLR(
optimizer,
self.linear_warmup_cosine_decay(warmup_steps, self.hparams.steps),
),
"interval": "step",
}
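        # "interval": "step" makes Lightning advance the LambdaLR scheduler after every
        # optimizer step rather than every epoch, matching the step-based schedule above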
return [optimizer], [scheduler]
def forward(self, x):
x = self.vit(x)
x = self.ln(x)
if self.hparams.classifier == Classifier.TOKEN:
            x = x[:, 0]  # only keep the class token, we're classifying anyway
elif self.hparams.classifier == Classifier.GAP:
x = x.mean(dim=1) # mean over sequence len
x = self.head(x)
return x
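    # Rough shape walk-through for the CIFAR-10 defaults (patch_size=2, dim=384,
    # classifier=TOKEN): images (B, 3, 32, 32) -> trunk output (B, 257, 384)
    # (256 patches plus the class token) -> logits (B, 10)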
def training_step(self, batch, _):
x, y = batch
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.logger.log_metrics(
{
"train_loss": loss.mean(),
"learning_rate": self.lr_schedulers().get_last_lr()[0],
},
step=self.global_step,
)
return loss
def evaluate(self, batch, stage=None):
x, y = batch
y_hat = self(x)
loss = self.criterion(y_hat, y)
acc = self.val_accuracy(y_hat, y)
if stage:
self.log(f"{stage}_loss", loss, prog_bar=True)
self.log(f"{stage}_acc", acc, prog_bar=True)
def validation_step(self, batch, _):
self.evaluate(batch, "val")
def test_step(self, batch, _):
self.evaluate(batch, "test")
if __name__ == "__main__":
pl.seed_everything(42)
    # Adjust the batch size depending on the memory available on your machine.
    # You can also use reversible layers to save memory.
REF_BATCH = 512
BATCH = 128
MAX_EPOCHS = 30
NUM_WORKERS = 4
GPUS = 1
# We'll use a datamodule here, which already handles dataset/dataloader/sampler
# - See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
# for a full tutorial
# - Please note that default transforms are being used
dm = CIFAR10DataModule(
data_dir="data",
batch_size=BATCH,
num_workers=NUM_WORKERS,
pin_memory=True,
)
image_size = dm.size(-1) # 32 for CIFAR
num_classes = dm.num_classes # 10 for CIFAR
# compute total number of steps
batch_size = BATCH * GPUS
steps = dm.num_samples // REF_BATCH * MAX_EPOCHS
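    # e.g. with REF_BATCH=512 and BATCH=128, gradients are accumulated over 4
    # micro-batches, so `steps` counts optimizer steps at the reference batch size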
lm = VisionTransformer(
steps=steps,
image_size=image_size,
num_classes=num_classes,
attention="scaled_dot_product",
classifier=Classifier.TOKEN,
residual_norm_style="pre",
use_rotary_embeddings=True,
)
trainer = pl.Trainer(
gpus=GPUS,
max_epochs=MAX_EPOCHS,
detect_anomaly=False,
precision=16,
accumulate_grad_batches=REF_BATCH // BATCH,
)
trainer.fit(lm, dm)
    # Evaluate the trained model on the held-out test set
trainer.test(lm, datamodule=dm)
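    # A minimal inference sketch (not run here), assuming `sample` is a
    # (1, 3, 32, 32) tensor normalized with the same transforms as the datamodule:
    #   lm.eval()
    #   with torch.no_grad():
    #       pred = lm(sample).argmax(dim=-1)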