CIFAR10 classification with transfer learning in PyTorch Lightning
There are a lot of mistakes you can make when programming neural networks in PyTorch. Small nuances, such as forgetting to call model.train() when using dropout or batch normalization, or forgetting model.eval() in your validation step, are easy to miss in all those lines of code.
With PyTorch Lightning¹ you don’t have to care about any of these, nor do you have to write a training loop or call .cuda() every time you want to send some parameters to your GPU.
Transfer learning is a very promising part of deep learning which allows you to use state-of-the-art architectures trained on large datasets. This is particularly interesting if you have a small dataset. It is important to check that the distribution of the original dataset is similar to your data.
In this case, I will use EfficientNet², introduced in 2019 by Mingxing Tan and Quoc V. Le. EfficientNet achieves state-of-the-art results faster and with far fewer parameters than previous approaches.
CIFAR10 consists of 60000 images with dimensions 3x32x32 and 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship and truck. There are 6000 images for each class.
Imports
Firstly, import libraries.
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.metrics.functional import accuracy

from efficientnet_pytorch import EfficientNet  # pip install efficientnet_pytorch
DataModule
PyTorch makes it easy for us to load CIFAR10 directly from torchvision datasets. We will make use of pl.LightningDataModule to download the data and create the training and validation DataLoaders.
class CifarDataModule(pl.LightningDataModule):
    def __init__(self, data_dir='./', batch_size=512):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    # Download data
    def prepare_data(self):
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    # Create train/val split
    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)
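As a quick sanity check, the DataModule can be used on its own before any training. This is just a minimal sketch (not part of the final training script); the printed shapes assume the batch size of 512 defined above.

cifar10 = CifarDataModule()
cifar10.prepare_data()              # downloads CIFAR10 if it is not cached yet
cifar10.setup('fit')                # builds the 45000/5000 train/val split
images, labels = next(iter(cifar10.train_dataloader()))
print(images.shape, labels.shape)   # torch.Size([512, 3, 32, 32]) torch.Size([512])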
pl.LightningModule
The next step is to define our model and training logic with pl.LightningModule, which inherits directly from nn.Module. The advantage of the Lightning module is that it removes boilerplate code (notice there is no optimizer.step(), etc.), but it is still the same old PyTorch.
With self.save_hyperparameters() there is no need to assign the parameters yourself (self.lr = lr, self.hidden_size = hidden_size, …); all __init__ arguments become accessible from self.hparams.
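For example, a minimal sketch (LitExample is just an illustrative name, not part of this tutorial's model) of how the saved arguments can be read back anywhere in the module:

class LitExample(pl.LightningModule):
    def __init__(self, hidden_size=64, lr=2e-4):
        super().__init__()
        self.save_hyperparameters()  # stores hidden_size and lr

    def configure_optimizers(self):
        # no self.lr attribute was ever set, yet the value is available:
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)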
I change the last linear layer so the number of outputs is equal to the number of classes in CIFAR10 and freeze all other layers for the first 5 epochs.
class LitCIFAR(pl.LightningModule):
    def __init__(self, hidden_size=64, lr=2e-4, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        # No need to write self.hidden_size = hidden_size

        # Define model: download pretrained weights for transfer learning
        self.model = EfficientNet.from_pretrained('efficientnet-b7')

        # Freeze all layers except the classifier
        for param in self.model.parameters():
            param.requires_grad = False
        self.model._fc = nn.Linear(self.model._fc.in_features, num_classes)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Unfreeze the whole network after the first 5 epochs
        if self.trainer.current_epoch == 5:
            for param in self.model.parameters():
                param.requires_grad = True
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        preds = F.softmax(logits, dim=1)
        self.log('train_loss', loss)
        self.log('train_acc', accuracy(preds, y))
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        preds = F.softmax(logits, dim=1)
        self.log('val_loss', loss)
        self.log('val_acc', accuracy(preds, y))
        return loss

    # Define optimizers and schedulers
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}
I decided to use SGD with a ReduceLROnPlateau scheduler, both of which are defined in the configure_optimizers method. The default logger for PyTorch Lightning is TensorBoard, and every scalar you pass to self.log() ends up there.
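The logger can also be configured explicitly. A small sketch (the save_dir and experiment name here are placeholder values I chose, not part of the original setup), to be passed to the Trainer introduced below:

from pytorch_lightning.loggers import TensorBoardLogger

tb_logger = TensorBoardLogger(save_dir='lightning_logs', name='cifar10_efficientnet')
# Then pass logger=tb_logger when building the Trainer.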
Callbacks
PyTorch Lightning contains a number of predefined callbacks with the most useful being EarlyStopping and ModelCheckpoint. However, it is possible to write any function and use it as a callback in trainer class.
EarlyStopping stops the training if the monitored metric doesn’t improve for a number of epochs defined in the patience parameter.
ModelCheckpoint saves the model with the best validation loss automatically.
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint

early_stopping = EarlyStopping(monitor='val_loss', patience=3, mode='min')

model_checkpoint_path = '/content/drive/My Drive/...'
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',  # val_loss is the default metric
    dirpath=model_checkpoint_path,
    filename='cifar10_scheduler-epoch{epoch:02d}-val_loss{val_loss:.2f}',
    save_top_k=1,  # How many top models to save
    mode='min')
Trainer
The Trainer class is where the magic happens. It is initialized with the maximum number of epochs, the GPUs to use and the list of callbacks defined earlier.
If you use a Jupyter notebook, set the progress_bar_refresh_rate parameter so it doesn’t freeze.
To start the training, call the fit method with your pl.LightningModule and pl.LightningDataModule.
cifar10 = CifarDataModule()
model = LitCIFAR()

trainer = pl.Trainer(max_epochs=20,
                     progress_bar_refresh_rate=20,
                     callbacks=[early_stopping, checkpoint_callback],
                     gpus=1)
trainer.fit(model, cifar10)
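Once training finishes, the best checkpoint saved by ModelCheckpoint can be restored for further evaluation. A minimal sketch, assuming the checkpoint_callback defined above:

best_path = checkpoint_callback.best_model_path        # checkpoint with the lowest val_loss
best_model = LitCIFAR.load_from_checkpoint(best_path)   # hyperparameters were saved, so no arguments needed
best_model.eval()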
If you used the self.log method in the training or validation step, the logged metrics end up in TensorBoard. In a notebook it can be loaded with:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/
To achieve higher accuracy we might try to train longer, fine-tune the learning rate, use more data augmentation (sketched below) or try different optimizers and schedulers.
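To illustrate the augmentation idea (this is my own sketch, not something used in the training run above), the training transform in CifarDataModule could be extended with random crops and horizontal flips, while the validation set keeps the plain normalization:

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),       # pad and crop back to 32x32
    transforms.RandomHorizontalFlip(),          # flip left-right with probability 0.5
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])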
There is much more to explore with PyTorch Lightning, so check out their website and docs!
Website: https://www.pytorchlightning.ai/
Docs: https://pytorch-lightning.readthedocs.io/en/latest/
If you liked this story and would like to get more tutorials on PyTorch and PyTorch Lightning consider buying me a coffee :)