In this tutorial I demonstrate how Keras can be used to quickly implement deep learning architectures for rapid prototyping. Specifically, I build and train two networks on the CIFAR-10 dataset with image augmentation: a simplified VGG network (built with the Sequential API) and a Wide Residual Network, WRN-40-2 (built with the functional API).
Why Keras?
In my earlier tutorial I demonstrated how to use the low-level TensorFlow API to build and train a simple convolutional neural network (CNN). But such an approach is code-heavy and quickly becomes cumbersome when trying to rapidly prototype multiple ideas. Keras was designed to eliminate much of this overhead, making it faster to move from idea to experiment.
As well-articulated by Keras creator Francois Chollet in his post on the recent success of Keras in winning Kaggle competitions:
"Your final models will typically share little in common with the solutions you envisioned when first approaching the problem, because ...a-priori plans basically never survive confrontation with experimental reality. So winning is not so much about how good your theoretical vision is, it's about how much contact with reality your vision has been through. ...[Using] Keras allows you to iterate faster, to try more things. Ultimately, this enables you to win competitions (or publish papers). Being able to go from idea to result with the least possible delay is key to doing good research —that’s one of the core beliefs behind Keras."
import pickle
import os
import math
import numpy as np
from datetime import datetime
from matplotlib import pyplot as plt
from IPython.display import Image
import tensorflow as tf
import keras as K
from keras.layers import Input, Conv2D, Dense
from keras.layers import BatchNormalization, SpatialDropout2D
from keras.layers import Activation, Dropout, MaxPooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import TensorBoard
from keras.callbacks import LearningRateScheduler
from keras.utils.vis_utils import plot_model
print('Keras Version:', K.__version__)
print('Keras Backend:', K.backend.backend())
print('Keras Image Data Format:', K.backend.image_data_format())
print('TensorFlow Version:', tf.__version__)
The CIFAR-10 dataset (found here) consists of 32x32-pixel RGB colour images of 10 classes (see below), with 6000 image samples per class. There are 50,000 training images and 10,000 test images in total. They were collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton at the University of Toronto, and are a labeled subset of the "80 million tiny images" dataset.
The 10 classes are:
airplane
automobile
bird
cat
deer
dog
frog
horse
ship
truck
General instructions for unpacking the data are found here. Each image data vector has 3072 elements. The first 1024 entries contain the red channel values, the next 1024 are the green values, and the final 1024 are the blue values. We reshape these vectors into arrays of dimension (32, 32, 3), where the third axis is the RGB channel. We reserve the final 5000 training images for our validation set.
# define data directory and helper functions
DATA_PATH = "/home/rpm/ML-models/cifar-10-batches-py/"
def unpickle(file):
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict
def load_batch(batch_name):
batch = os.path.join(DATA_PATH, batch_name)
batch_dict = unpickle(batch)
print('Loading %s' % batch_dict[b'batch_label'].decode('ascii'))
X_data_ = batch_dict[b'data']
X_data_ = X_data_.reshape(10000, 3, 32, 32) # separate RGB channels
X_data_ = X_data_.transpose(0,2,3,1).astype("float") # re-order
y_labels_ = batch_dict[b'labels']
img_filenames_ = batch_dict[b'filenames']
return (X_data_, y_labels_, img_filenames_)
# load training data
TRAIN_BATCH_LIST = [("data_batch_%s" % i) for i in range(1,6)]
X_arrs, y_arrs, name_arrs = [], [], []
for batch_name in TRAIN_BATCH_LIST:
(X_arr, y_arr, name_arr) = load_batch(batch_name)
X_arrs.append(X_arr)
y_arrs.append(np.array(y_arr))
name_arrs.append(name_arr)
X_train = np.vstack(X_arrs).astype('float32')
y_train = np.hstack(y_arrs)
img_filenames_train = np.vstack(name_arrs)
# load test data
(X_test, y_test, img_filenames_test) = load_batch('test_batch')
y_test = np.array(y_test)
# ensure desired data format
X_train = X_train.astype('float32')
y_train = y_train.astype('float32')
X_test = X_test.astype('float32')
y_test = y_test.astype('float32')
# define label dictionary
label_dict = {0: 'airplane',
1: 'automobile',
2: 'bird',
3: 'cat',
4: 'deer',
5: 'dog',
6: 'frog',
7: 'horse',
8: 'ship',
9: 'truck'}
# use final 5000 training samples (10%) as validation set
X_val = X_train[45000:].astype('float32')
y_val = y_train[45000:]
X_train = X_train[:45000].astype('float32')
y_train = y_train[:45000]
# print data shapes
print('\nX_train shape:', X_train.shape)
print('X_val shape:', X_val.shape)
print('X_test shape:', X_test.shape)
print('y_train shape:', y_train.shape)
print('y_val shape:', y_val.shape)
print('y_test shape:', y_test.shape)
Let's have a look at 8 sample images from each class:
for j in range(10):
print(label_dict[j])
fig, ax = plt.subplots(nrows=1, ncols=8,
sharex=True, sharey=True,
figsize=(12, 8))
ax = ax.flatten()
for i in range(8):
img = X_train[y_train == j][i].astype('uint8')
ax[i].imshow(img, interpolation='None')
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.show()
First, let's use Keras's built-in to_categorical utility to define a convenient function for one-hot encoding our class labels later on during model training and evaluation.
def get_onehot(y_data, num_classes=10):
return K.utils.to_categorical(y_data, num_classes)
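As a quick sanity check (illustrative only), a label of 3 should map to a 10-element vector with a 1 in the fourth position:
# quick check of the one-hot encoding
print(get_onehot(np.array([3, 7])))
# [[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
#  [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]]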
Next we'll implement image augmentation, which is a technique for improving the generalization of our models by applying randomized image transformations (such as shifts, rotations, shears, zooms, and flips) to the original dataset. This helps force the model to learn feature maps that are more robust and relevant. We are essentially creating a larger training dataset from a smaller one. For excellent introductory articles on this topic, see these posts by Bharath Raj and Ryan Allred.
For our WRN model, it turns out that image augmentation is essential, otherwise our network basically just memorizes our training set (training accuracy reaches 100%) and yields poor generalization to the test set.
Fortunately Keras has built-in tools for applying image augmentation, with examples available here.
We'll base our image augmentation on the scheme recommended by Sergey Zagoruyko and Nikos Komodakis in their seminal paper on Wide ResNets (WRNs). Quoting their paper:
"For data augmentation we do horizontal flips and take random crops from image padded by 4 pixels on each side, filling missing pixels with reflections of original image. We don’t use heavy data augmentation... Unless mentioned otherwise, for CIFAR we follow the image preprocessing of [8] with ZCA whitening."
Note that ZCA whitening is a linear algebra transformation, similar to principal component analysis (PCA), that reduces pixel redundancy to better highlight the key structures and features in the image. We'll use ZCA whitening to preprocess and standardize our data. We'll also look at how it compares to typical data standardization (mean/std) for our particular WRN model.
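To make the transformation a little more concrete, here is a minimal NumPy sketch of how a ZCA whitening matrix can be computed from flattened, zero-centred image data. This is purely illustrative (and expensive for 3072-dimensional vectors); the ImageDataGenerator below handles all of this for us, and the function name zca_whiten is just our own.
def zca_whiten(X_flat, epsilon=1e-6):
    """Illustrative ZCA whitening of flattened images (rows = samples)."""
    X_centered = X_flat - X_flat.mean(axis=0)        # zero-centre each pixel/feature
    cov = np.cov(X_centered, rowvar=False)           # feature covariance matrix
    U, S, _ = np.linalg.svd(cov)                     # eigendecomposition via SVD
    W_zca = U.dot(np.diag(1.0 / np.sqrt(S + epsilon))).dot(U.T)  # whitening matrix
    return X_centered.dot(W_zca)

# e.g. whiten a small subset of the training images:
# X_white = zca_whiten(X_train[:1000].reshape(1000, -1))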
Below we create generators for augmented images that also apply ZCA whitening and feature normalization.
# augmented image generator based on scheme
# recommended by Zagoruyko and Komodakis
train_datagen_1 = ImageDataGenerator(
featurewise_center=True,
samplewise_center=False,
featurewise_std_normalization=False,
samplewise_std_normalization=False,
zca_whitening=True,
zca_epsilon=1e-06,
rotation_range=0.0,
width_shift_range=0.25,
height_shift_range=0.25,
brightness_range=None,
shear_range=0.0,
zoom_range=0.0,
channel_shift_range=0.0,
fill_mode='reflect',
cval=0.0,
horizontal_flip=True,
vertical_flip=False,
rescale=None,
preprocessing_function=None,
data_format='channels_last',
validation_split=0.0)
# an alternative image augmentation scheme
# using zooms, rotations, and shears
train_datagen_2 = ImageDataGenerator(
featurewise_center=True,
samplewise_center=False,
featurewise_std_normalization=False,
samplewise_std_normalization=False,
zca_whitening=True,
zca_epsilon=1e-06,
rotation_range=15.0,
width_shift_range=0.1,
height_shift_range=0.1,
brightness_range=None,
shear_range=0.2,
zoom_range=0.2,
channel_shift_range=0.0,
fill_mode='nearest',
cval=0.0,
horizontal_flip=True,
vertical_flip=False,
rescale=None,
preprocessing_function=None,
data_format='channels_last',
validation_split=0.0)
# for the test set, apply ZCA and centering only
test_datagen = ImageDataGenerator(
featurewise_center=True,
zca_whitening=True,
data_format='channels_last')
# data normalization and ZCA whitening require our
# generators to first be fit on our training data
train_datagen_1.fit(np.vstack((X_train, X_val)))
train_datagen_2.fit(np.vstack((X_train, X_val)))
test_datagen.fit(np.vstack((X_train, X_val)))
Let's have a look at what our image augmentation schemes produce for a sample input image. To make this easier to see, I've defined two new generators with the ZCA whitening transformation removed.
# sample image to augment
x_temp = X_train[4].reshape(1, 32, 32, 3)
# recreate the generators but without ZCA whitening
demo_datagen_1 = ImageDataGenerator(
rotation_range=0.0,
width_shift_range=0.25,
height_shift_range=0.25,
fill_mode='reflect',
horizontal_flip=True,
data_format='channels_last')
demo_datagen_2 = ImageDataGenerator(
rotation_range=15.0,
width_shift_range=0.1,
height_shift_range=0.1,
shear_range=0.2,
zoom_range=0.2,
fill_mode='nearest',
horizontal_flip=True,
data_format='channels_last')
# plot examples from scheme 1
print('\nData Augmentation Scheme 1:')
i=0
figure = plt.figure(figsize=(14, 8))
for batch in demo_datagen_1.flow(x_temp, batch_size=1):
# display the images
ax = figure.add_subplot(4, 5, i + 1, xticks=[], yticks=[])
img = np.squeeze(batch).astype('uint8')
ax.imshow(img, interpolation='nearest')
i+=1
if i==20: # otherwise the generator would continue indefinitely
break
# plot examples from scheme 2
print('\nData Augmentation Scheme 2:')
i=0
figure = plt.figure(figsize=(14, 8))
for batch in demo_datagen_2.flow(x_temp, batch_size=1):
# display the images
ax = figure.add_subplot(4, 5, i + 1, xticks=[], yticks=[])
img = np.squeeze(batch).astype('uint8')
ax.imshow(img, interpolation='nearest')
i+=1
if i==20: # otherwise the generator would continue indefinitely
break
We'll first turn our attention to the famous VGG (Visual Geometry Group) architecture, which was highly influential in its approach to convolutional networks for image recognition. Compared to earlier architectures, VGG was remarkable in both its simplicity and its layer depth. It is based on a hierarchical representation of visual data that allows small convolutional filters (e.g. 3x3 pixels) to capture rich feature information so long as the network is sufficiently deep. To better appreciate VGG's place in the evolution of state-of-the-art networks for image recognition, I recommend reading this article on the history of neural networks by Prof. Eugenio Culurciello at Purdue University.
The core design rules of VGG network construction are as follows:
- For a given output feature map size (e.g. 32x32 or 16x16), you use the same number of filters (channels) for each convolution layer.
- Each time the feature map size is halved (e.g. by maxpooling with a stride of 2), you double the number of filters (channels) to maintain the same "time complexity per layer".
- You end with pooling followed by fully-connected (dense) layers and a softmax output. (These rules are sketched in compact code form just below.)
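To see how little code these rules really require, here is a compact, illustrative skeleton that expresses them as a loop: the filter count is fixed within a block and doubles after each 2x downsampling. The function name vgg_style_skeleton is our own, and the model we actually train is written out explicitly, layer by layer, further below.
def vgg_style_skeleton(block_filters=(64, 128, 256), convs_per_block=2):
    model = K.Sequential()
    model.add(K.layers.InputLayer(input_shape=(32, 32, 3)))
    for filters in block_filters:              # 64 -> 128 -> 256 (rule 2)
        for _ in range(convs_per_block):       # same filter count within a block (rule 1)
            model.add(Conv2D(filters, kernel_size=3, padding='same', activation='relu'))
        model.add(MaxPooling2D(pool_size=2))   # halve the feature map size
    model.add(K.layers.Flatten())
    model.add(Dense(10, activation='softmax')) # softmax output (rule 3)
    return model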
A table summarizing various VGG models, taken from the original research paper, is shown below, where convN-K refers to a convolutional layer having filter size NxN pixels and K filter channels.
The relative simplicity of VGG makes it an excellent first example for showing how we can build such networks in Keras. On the other hand, its large number of trainable parameters makes it computationally expensive. Hence we'll implement a simplified version of VGG that is similar to configuration B (see table above) but has one less convolution block and only one fully-connected dense layer at the output.
For architectures such as VGG where all network layers are arranged in series, we can use Keras' Sequential API for model construction. We first create an instance of the Sequential() model object and then add layers in order. The pre-built network layers we'll be using include Conv2D, BatchNormalization, MaxPooling2D, and Dense. A Flatten layer is used to concatenate and reshape all feature maps into a single vector prior to applying the fully-connected layer. The handy summary() method allows us to view details about the model, including the number of trainable and non-trainable parameters.
def build_SVGG():
model = K.Sequential()
# 1. First Conv block
model.add(Conv2D(filters=64,
kernel_size=3,
padding='same',
activation='relu',
input_shape=(32, 32, 3)))
model.add(BatchNormalization(axis=-1))
model.add(Conv2D(filters=64,
kernel_size=3,
padding='same',
activation='relu'))
model.add(BatchNormalization(axis=-1))
model.add(MaxPooling2D(pool_size=2))
# 2. Second Conv block
model.add(Conv2D(filters=128,
kernel_size=3,
padding='same',
activation='relu'))
model.add(BatchNormalization(axis=-1))
model.add(Conv2D(filters=128,
kernel_size=3,
padding='same',
activation='relu'))
model.add(BatchNormalization(axis=-1))
model.add(MaxPooling2D(pool_size=2))
# 3. Third Conv block
model.add(Conv2D(filters=256,
kernel_size=3,
padding='same',
activation='relu'))
model.add(BatchNormalization(axis=-1))
model.add(Conv2D(filters=256,
kernel_size=3,
padding='same',
activation='relu'))
model.add(BatchNormalization(axis=-1))
model.add(MaxPooling2D(pool_size=2))
# 4. Dense layer with dropout applied to flattened data
model.add(K.layers.Flatten())
model.add(K.layers.Dense(1024, activation='relu'))
model.add(K.layers.Dropout(0.5))
model.add(K.layers.Dense(10, activation='softmax'))
return model
# for inspection purposes
model_SVGG = build_SVGG()
model_SVGG.summary()
We see that this simplified VGG model has roughly 5.3 million parameters.
Before we train our model, we need to compile it using the compile() method. It is at this time that we specify the loss function, the optimizer, and any performance metrics to track. For example, to train a model using categorical cross-entropy loss with optimization via stochastic gradient descent while tracking accuracy as the performance metric, we use:
model_SVGG.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
There are a few other things we need to know about before proceeding. One is about the use of callbacks. Callbacks allow us to apply functions or access model parameters during the training process. For example, we can use callbacks to access the model weights after each training iteration and save these to file if they result in a better minimization of the loss function than previous iterations. We can also use callbacks to generate TensorBoard files for viewing a model's performance and for detailed visualization of its computational graph. For more about callbacks, see the Keras documentation.
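As a quick illustration (these particular callbacks are not part of the training pipeline below), a checkpointer and an early-stopping monitor can be created and handed to the training call like so; the filename here is just a placeholder:
# illustrative only: save the best weights seen so far, and stop
# training early if the validation loss stops improving
checkpoint_cb = K.callbacks.ModelCheckpoint(
    filepath='demo_best_weights.hdf5',
    monitor='val_loss',
    save_best_only=True,
    verbose=1)
early_stop_cb = K.callbacks.EarlyStopping(monitor='val_loss', patience=10)
# these would then be passed via the callbacks argument,
# e.g. model.fit(..., callbacks=[checkpoint_cb, early_stop_cb])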
Finally, one typically trains, evaluates, and utilizes Keras models by calling the fit(), evaluate(), and predict() methods respectively. In this case, however, when working with data generators for image augmentation, we need to call slightly different methods, namely fit_generator(), evaluate_generator(), and predict_generator(), where our train and test set data generators are given as arguments. Note that we cannot mix these sets of methods: even though we don't use image augmentation on the test set, we cannot simply standardize the test data and call predict() after training a model with fit_generator(); doing so will yield erroneous results, hence why we use a generator for the test set as well.
Calling fit() or fit_generator() on a model object returns a history object that we can use to plot training and validation accuracy as a function of training epoch, which is great for assessing how well the model converges and generalizes (e.g. the bias-variance tradeoff).
Below we define some helper functions that make it easier for us to train and evaluate our Keras models with fewer lines of repeated code:
# perform all tasks related to model training
def compile_and_train(model,
MODEL_NAME,
X_train_data,
y_train_data,
X_val_data,
y_val_data,
train_generator,
test_generator,
BATCH_SIZE,
EPOCHS,
optimizer='rmsprop',
lr_schedule_fn=None,
starting_epoch=0):
# directory setup for saving model files
MODEL_DIR = '/home/rpm/Dropbox/Keras-Devs/models/CIFAR-10/'
DIR = os.path.join(
MODEL_DIR,
MODEL_NAME,
datetime.now().strftime('%Y%m%d-%H%M%S'))
if not os.path.exists(DIR):
os.makedirs(DIR)
LOG_DIR = os.path.join(DIR, 'logs')
# checkpointer callback for saving best weights
WEIGHT_DIR = os.path.join(DIR, 'best_weights.hdf5')
checkpointer_cb = K.callbacks.ModelCheckpoint(
filepath=WEIGHT_DIR,
verbose=1,
save_best_only=True)
# callback to save tensorboard logs
tensorboard_cb = TensorBoard(
log_dir=LOG_DIR,
histogram_freq=0,
batch_size=BATCH_SIZE,
write_graph=True,
write_grads=False,
write_images=False)
print('To view graph generated by TensorBoard,')
print('copy+paste the following into terminal:')
print('\ntensorboard --logdir=%s' % LOG_DIR)
print('\n')
# special callbacks for a scheduled learning rate
if lr_schedule_fn is not None:
# learning rate scheduler
lr_schedule_cb = K.callbacks.LearningRateScheduler(
    lr_schedule_fn, verbose=0)
# custom history callback
class LossHistory(K.callbacks.Callback):
def on_train_begin(self, logs={}):
self.losses = []
self.lr = []
def on_epoch_end(self, batch, logs={}):
self.losses.append(logs.get('loss'))
self.lr.append(lr_schedule_fn(len(self.losses)))
loss_and_lr_history_cb = LossHistory()
callback_list = [checkpointer_cb,
tensorboard_cb,
lr_schedule_cb,
loss_and_lr_history_cb]
else:
callback_list = [checkpointer_cb,
tensorboard_cb]
# compile the model
model.compile(loss='categorical_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
# one-hot encode labels
y_train_data_onehot = get_onehot(y_train_data)
y_val_data_onehot = get_onehot(y_val_data)
# prepare data generators
train_gen = train_generator.flow(
X_train_data,
y_train_data_onehot,
batch_size=BATCH_SIZE)
val_gen = test_generator.flow(
X_val_data,
y_val_data_onehot,
batch_size=BATCH_SIZE)
# train the model
history = model.fit_generator(
generator=train_gen,
steps_per_epoch=len(X_train_data)/BATCH_SIZE,
epochs=EPOCHS,
verbose=1,
callbacks=callback_list,
validation_data=val_gen,
validation_steps=len(X_val_data)/BATCH_SIZE,
class_weight=None,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
shuffle=True,
initial_epoch=starting_epoch)
# save trained model (after final epoch)
model.save(os.path.join(DIR, 'model_' + MODEL_NAME + '.h5'))
if lr_schedule_fn is not None:
return (history, DIR, loss_and_lr_history_cb)
else:
return (history, DIR)
# function for plotting metrics versus
# training epoch from a history object
def plot_histories(history_):
# summarize history for accuracy
plt.plot(history_.history['acc'])
plt.plot(history_.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['training', 'validation'], loc='upper left')
plt.show()
# summarize history for loss
plt.plot(history_.history['loss'])
plt.plot(history_.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['training', 'validation'], loc='upper left')
plt.show()
return None
# helper function for evaluating performance
def evaluate_performance(model_obj,
model_dir,
X_test_data,
y_test_data,
test_generator,
model_hist=None,
model_lr_hist=None):
# plot accuracies versus epoch
if model_hist is not None:
print('\nPlotting Training History:')
plot_histories(model_hist)
# plot learning rate if scheduled
if model_lr_hist is not None:
print('\nPlotting Learning Rate Schedule:')
plt.plot(model_lr_hist.lr)
plt.ylabel('Learning Rate')
plt.xlabel('Epoch')
plt.show()
plt.plot(model_lr_hist.losses)
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.show()
# one-hot encoding
y_test_data_onehot = get_onehot(y_test_data)
# prepare data generator
test_gen = test_generator.flow(
X_test_data,
y_test_data_onehot,
batch_size=1)
# evaluate with End-Of-Training weights
print('\nComputing EOT Score...')
score_EOT = model_obj.evaluate_generator(
generator=test_gen,
steps=len(y_test_data),
max_queue_size=10,
workers=1,
use_multiprocessing=False,
verbose=1)
print('Test Accuracy with EOT Weights:', score_EOT[1])
# evaluate with Best Weights
model_obj.load_weights(
os.path.join(model_dir, 'best_weights.hdf5'))
print('\nComputing Best Weights Score...')
score_BW = model_obj.evaluate_generator(
generator=test_gen,
steps=len(y_test_data),
max_queue_size=10,
workers=1,
use_multiprocessing=False,
verbose=1)
print('Test Accuracy with Best Weights:', score_BW[1])
scores = (score_EOT[1], score_BW[1])
return scores
With our handy helper functions, we can now build and train the simplified VGG network in a few short chunks of code. Note that I'm substituting the test set as the validation set here, so we can see how the test accuracy evolves with model training; I can do this because the test set is never used to update the model weights or to select hyperparameters. Although I made a validation-set split earlier out of habit, here I'm training the model on the full 50,000-sample training set.
# reset the tensorflow graph
K.backend.clear_session()
# construct model
model_SVGG = build_SVGG()
# initiate training
this_name = 'SVGG_AugScheme1_FullTrainSet'
(this_history, this_dir) = compile_and_train(
model=model_SVGG,
MODEL_NAME=this_name,
X_train_data=np.vstack((X_train, X_val)),
y_train_data=np.hstack((y_train, y_val)),
X_val_data=X_test,
y_val_data=y_test,
train_generator=train_datagen_1,
test_generator=test_datagen,
BATCH_SIZE=128,
EPOCHS=400,
optimizer='rmsprop',
lr_schedule_fn=None)
Now let's run our evaluation helper function to see how this model did:
# function for evaluating performance
this_score = evaluate_performance(
model_obj=model_SVGG,
model_dir=this_dir,
X_test_data=X_test,
y_test_data=y_test,
test_generator=test_datagen,
model_hist=this_history,
model_lr_hist=None)
Performance on the training set saturates fairly quickly (around epoch 100), whereas the loss on the validation set appears to reach an optimum near 50 epochs and then fluctuates significantly. To obtain better performance, we could try a different optimizer or specify a learning rate schedule (a sketch of the latter is given below). For now, we'll note 86.4% as the test accuracy obtained by this model.
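For example, a step-decay schedule in the spirit of the WRN paper could be supplied through the lr_schedule_fn argument of our compile_and_train helper. This is a sketch only; the milestones and decay factor here are illustrative, not tuned:
# an illustrative step-decay schedule; pass it to compile_and_train
# via lr_schedule_fn=lr_schedule
def lr_schedule(epoch):
    lr = 0.1
    for milestone in (60, 120, 160):   # drop the learning rate by 5x at each milestone
        if epoch >= milestone:
            lr /= 5.0
    return lr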
Let's view a random sample of the images our trained model has classified correctly or incorrectly. The following code accomplishes this.
# generate the predictions
test_generator = test_datagen.flow(X_test,
get_onehot(y_test),
batch_size=1)
probas = model_SVGG.predict_generator(test_generator,
steps=len(y_test))
preds_SVGG = np.argmax(probas, axis=1)
# function for viewing predictions
def view_predictions(model_, preds_, condition=None):
if condition is None:
print('\n**Displaying 20 Randomly-Selected Predictions**')
inds = np.arange(len(preds_))
elif condition == 'Correct':
print('\n**Displaying Correct Predictions**')
inds = np.arange(len(preds_))[preds_ == y_test]
elif condition == 'Incorrect':
print('\n**Displaying Incorrect Predictions**')
inds = np.arange(len(preds_))[preds_ != y_test]
else:
raise ValueError(
'condition must be None, "Correct", or "Incorrect"')
# plot 20 randomly-chosen samples
figure = plt.figure(figsize=(14, 8))
for i, index in enumerate(np.random.choice(inds,
size=20,
replace=False)):
# display the images
ax = figure.add_subplot(4, 5, i + 1, xticks=[], yticks=[])
ax.imshow(X_test[index].astype('uint8'), interpolation='None')
# from one-hot encoded back to integer label
pred_int = preds_[index]
true_int = y_test[index]
# set title: Prediction (True Class)
ax.set_title("{} ({})".format(
label_dict[pred_int],
label_dict[true_int]),
color=("green" if pred_int == true_int else "red"))
return None
# correct predictions
view_predictions(model_SVGG, preds_SVGG, condition='Correct')
# incorrect predictions
view_predictions(model_SVGG, preds_SVGG, condition='Incorrect')
When dealing with convolutional neural networks, it's not always easy to understand why a particular image has been misclassified. However, if desired, we could examine the weights of the filter channels and network activations to get a better sense of which features are associated with each class.
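As a sketch of how that kind of inspection might start (illustrative only; the choice of the first six layers here is arbitrary), Keras lets us wrap an existing model so that intermediate activations become outputs:
# illustrative sketch: expose the outputs of the first few layers and
# inspect the feature-map shapes for a single preprocessed test image
layer_outputs = [layer.output for layer in model_SVGG.layers[:6]]
activation_model = K.Model(inputs=model_SVGG.input, outputs=layer_outputs)
sample_batch, _ = next(test_datagen.flow(X_test[:1], get_onehot(y_test[:1]), batch_size=1))
activations = activation_model.predict(sample_batch)
print([a.shape for a in activations])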
Next we turn our attention to the slightly more challenging task of building a Wide Residual Network (WRN). Residual Networks (or ResNets) were considered a significant advance when they were first introduced.
Why ResNets?
Prior to ResNets, the trend (with AlexNet, VGG, and Google's Inception networks) was to go deeper to achieve better performance. The issue with this was that deep networks are hard to train due to the vanishing gradient problem -- as the gradient is back-propagated to earlier layers, the repeated multiplication of small numbers can make it vanish, meaning that as the network gets deeper, its performance saturates or even starts to degrade.
The core idea of ResNets is to introduce an "identity shortcut connection" (illustrated below, from the original paper) that skips one or more layers, thereby creating additional, shorter paths through which the gradient can flow. This allows for the training of extremely deep models with hundreds of layers -- in some cases, more than one thousand -- before vanishing gradients become problematic. This in turn led to new record performance on several benchmark image classification datasets.
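In code, the shortcut is nothing more than an element-wise addition of a block's input to its output. A minimal sketch (using the Keras functional API, which is introduced properly later in this tutorial; the layer sizes here are arbitrary):
# a minimal illustration of an identity shortcut connection
shortcut_in = Input(shape=(32, 32, 16))
y = Conv2D(16, kernel_size=3, padding='same', activation='relu')(shortcut_in)
y = Conv2D(16, kernel_size=3, padding='same')(y)
y = K.layers.Add()([y, shortcut_in])   # the skip connection: add the block input back in
shortcut_out = Activation('relu')(y)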
Why "Wide" ResNets?
It was later discovered, however, that randomly removing ResNet layers often had little to no impact on performance. Many ResNet layers appeared to be redundant or contributed negligible value. One explanation for this was that the large multiplicity of paths over which the gradient could flow meant that many ResNet layer weights could be effectively 'skipped over' during backpropagation and hence were not participating efficiently in learning.
This led to further experimentation, and it was found that making ResNets "wider", as opposed to deeper, could result in improved performance, far fewer layers, and faster training. An oft-quoted result is that a 16-layer Wide ResNet (WRN) can match the accuracy of a 1000-layer-deep conventional ResNet.
WRN Design Rules
The basic design rules for WRNs, as laid out in this seminal paper by Sergey Zagoruyko and Nikos Komodakis, are as follows:
- Convolution filters should be no larger than 3x3 pixels for best performance (unless a bottleneck architecture is used).
- Follow the convention that when you downsample by 2, you double the number of feature maps (filter channels).
- Three important factors are introduced: l, the deepening factor, which represents the number of convolutional layers in a ResNet block (often l = 2 is optimal); N, the number of ResNet blocks (repeats; miniblocks) within a given convolutional group (wherein all layers in a group have the same fixed number of filter channels); and K, the widening factor, which multiplies the number of feature channels in each convolution layer.
- A batch normalization and ReLU activation precedes all convolution layers in ResNet blocks beyond the first block.
Following these guidelines, the basic structure of a WRN given by the paper above is as follows:
The conv1 group is the network input and has only one conv layer. Groups conv2, conv3, and conv4 are the residual-block conv groups whose parameters are specified by l (implicitly l = 2 here), N, and K.
The common naming convention for identifying specific realizations of WRNs (with l = 2) is WRN-n-K, where n is the total number of convolutional layers and K is the widening factor. Note that n can take only specific values: there are 3 residual conv groups, each contributing 2 conv layers times N miniblocks, plus 4 additional conv layers (one at the network input, and one at the beginning of each conv group's residual path to match feature map sizes between groups). In our case, we'll work with WRN-40-2, which means we have (40 - 4)/(3x2) = 6 miniblocks in conv groups 2-4 and a widening factor of K = 2.
For more details, see the WRN creators' github page for a network diagram. Their model was implemented in Torch.
Additional Notes
Zagoruyko and Komodakis reported a performance improvement when dropout was added between convolution layers. Note that this is typically not done when batch normalization (BN) is used, as dropout and BN often don't work well together (see this paper). While BN is mostly a technique for improving optimization, it also happens to introduce noise, which can help with model generalization. And since for large datasets it is usually more important to optimize well than to regularize, BN is often preferred over dropout.
Nonetheless, there can be exceptions where BN and dropout seem to complement each other (see here, for example). I have chosen to follow the suggestion of Zagoruyko and Komodakis and include dropout, and have found that it does help here in improving the model performance.
I have also chosen to follow their recommendations on the optimization technique:
"In all our experiments we use SGD with Nesterov momentum and cross-entropy loss. The initial learning rate is set to 0.1, weight decay to 0.0005, dampening to 0, momentum to 0.9 and minibatch size to 128."
The Keras functional API makes it easy to define layers that branch into parallel paths and then re-merge later. Note that we can also define auxiliary inputs and loss functions, as well as our own custom layer functions. For more, see the official documentation.
To use the functional API, we simply deploy our Keras built-in layers as functions like so:
output_layer = LayerType(layer_parameters)(input_layer)
The logic of how this works will quickly become apparent in the following code blocks.
def resnet_input_block(input_layer,
num_filters):
x = Conv2D(filters=num_filters,
kernel_size=3,
padding='same')(input_layer)
x = BatchNormalization(axis=-1)(x)
x_out = Activation('relu')(x)
return x_out
def resnet_block(x_in,
num_filters,
num_miniblocks,
downsampling,
dropout_rate=0.0):
'''
** Input Parameters **
x_in : input layer tensor (batchsize, width, height, channels)
num_filters : total number of feature maps including
the widening factor multiplier
num_miniblocks : number of times we repeat the miniblock
(a.k.a. "groupsize" in the WRN paper)
downsampling : factor by which we reduce the output size,
i.e. downsampling of 2 takes 16x16 --> 8x8;
downsampling here is applied by adjusting
the stride of the first convolutional layers
rather than via pooling layers.
** Returns **
x_out : output layer for the next network block
'''
# residual path
x_res = Conv2D(filters=num_filters,
kernel_size=1,
padding='same',
strides=downsampling)(x_in)
# main path mini-block
x_main = Conv2D(filters=num_filters,
kernel_size=3,
padding='same',
strides=downsampling)(x_in)
x_main = BatchNormalization(axis=-1)(x_main)
x_main = Activation('relu')(x_main)
x_main = SpatialDropout2D(rate=dropout_rate)(x_main)
x_main = Conv2D(filters=num_filters,
kernel_size=3,
padding='same')(x_main)
# merge
x_out = K.layers.Add()([x_main, x_res])
# *** additional miniblocks ***
for block in range(num_miniblocks-1):
# main path mini-block
x_main = BatchNormalization(axis=-1)(x_out)
x_main = Activation('relu')(x_main)
x_main = Conv2D(filters=num_filters,
kernel_size=3,
padding='same')(x_main)
x_main = SpatialDropout2D(rate=dropout_rate)(x_main)
x_main = BatchNormalization(axis=-1)(x_main)
x_main = Activation('relu')(x_main)
x_main = Conv2D(filters=num_filters,
kernel_size=3,
padding='same')(x_main)
# merge
x_out = K.layers.Add()([x_main, x_out])
# ***
# capping the resnet block
x = BatchNormalization(axis=-1)(x_out)
x_out = Activation('relu')(x)
return x_out
def resnet_output_block(x_in):
# auto-adjust pooling based on input shape
dim1_ = x_in.shape[1].value
dim2_ = x_in.shape[2].value
assert dim1_ == dim2_, 'Input layer dimensions must be square.'
# this generates a single average value for each feature map
x = K.layers.AveragePooling2D(pool_size=dim1_)(x_in)
# flatten and apply a fully-connected layer
x = K.layers.Flatten()(x)
x = Dense(128, activation='relu')(x)
# obtain probabilities for each class
class_probas = Dense(10, activation='softmax')(x)
return class_probas
def build_WRN(num_convs=40, k_factor=2, drop_rate=0.0):
"""
Builds a wide residual network (WRN) of type
WRN-N-K where N is the total number of convolutions
and K is the widening factor.
** Input Parameters **
num_convs: Total number of convolutions. Must be an
integer value derivable from (i * 2 * 3) + 4,
where i is an integer (e.g. 10, 16, 22, 28,
34, and 40 are acceptable values)
k_factor: Widening factor, multiplies the number of channels
in each resnet block convolution. Must be an
integer value.
drop_rate: Fraction of features dropped during SpatialDropout2D
between conv layers in each residual block.
** Returns **
model: WRN model object built with Keras' functional API.
"""
assert type(num_convs) is int, \
"num_convs is not an integer: %r" % num_convs
assert type(k_factor) is int, \
"k_factor is not an integer: %r" % k_factor
assert num_convs in [10, 16, 22, 28, 34, 40], \
"num_convs must be one of [10, 16, 22, 28, 34, 40]"
num_miniblocks = int(math.ceil((num_convs - 4)/6))
input_layer = Input(shape=(32, 32, 3))
# input block ('conv1' group in WRN paper)
x = resnet_input_block(input_layer, num_filters=16)
# ('conv2' group in WRN paper)
x = resnet_block(x_in=x,
num_filters=16*k_factor,
num_miniblocks=num_miniblocks,
downsampling=2,
dropout_rate=drop_rate)
# ('conv3' group in WRN paper)
x = resnet_block(x_in=x,
num_filters=32*k_factor,
num_miniblocks=num_miniblocks,
downsampling=2,
dropout_rate=drop_rate)
# ('conv4' group in WRN paper)
x = resnet_block(x_in=x,
num_filters=64*k_factor,
num_miniblocks=num_miniblocks,
downsampling=2,
dropout_rate=drop_rate)
# output block
probas = resnet_output_block(x_in=x)
# assemble
model = K.Model(inputs=input_layer, outputs=probas)
return model
Network Summary for WRN-40-2
Let's build WRN-40-2 with a dropout rate of 30% and view the model summary:
K.backend.clear_session()
model_WRN = build_WRN(40, 2, drop_rate=0.3)
model_WRN.summary()
Our network has a total of 2.27 million parameters, which is roughly half that of our earlier VGG model.
Visualizing the Network Graph
With a more complex model such as a ResNet, it's useful to check our network construction by viewing its computational graph or displaying the relationships between its layers in a more visually intuitive format. Fortunately, Keras has a plot_model function that allows us to easily accomplish this:
# visualizing the network
plot_model(model_WRN, to_file='model_plot.png',
show_shapes=True, show_layer_names=False)
Image('model_plot.png')
To train this WRN network, we call the same helper functions that we used for the VGG network. Quick and convenient! In fact, we could reuse these helper functions for any network model we define in Keras.
# reset the tensorflow graph
K.backend.clear_session()
# construct model
model_WRN = build_WRN(40, 2, drop_rate=0.3)
# define the optimizer initial settings
sgd = K.optimizers.SGD(
lr=0.1,
decay=0.0005,
momentum=0.9,
nesterov=True)
# initiate training
this_name = 'WRN_AugScheme1_FullTrainSet'
(this_history, this_dir) = compile_and_train(
model=model_WRN,
MODEL_NAME=this_name,
X_train_data=np.vstack((X_train, X_val)),
y_train_data=np.hstack((y_train, y_val)),
X_val_data=X_test,
y_val_data=y_test,
train_generator=train_datagen_1,
test_generator=test_datagen,
BATCH_SIZE=128,
EPOCHS=400,
optimizer=sgd,
lr_schedule_fn=None)
# function for evaluating performance
this_score = evaluate_performance(
model_obj=model_WRN,
model_dir=this_dir,
X_test_data=X_test,
y_test_data=y_test,
test_generator=test_datagen,
model_hist=this_history,
model_lr_hist=None)
This WRN model achieved a test accuracy of 89.3%, roughly 3% better than the VGG model we trained earlier, while using only half the parameters. Note that state-of-the-art performance for WRNs on CIFAR-10 is closer to 95% accuracy, so with further parameter tuning our model certainly has room for improvement. From the plots of accuracy and loss versus epoch, our metrics haven't quite plateaued yet, so we might also consider training the model for additional epochs.
As before, let's view some sample predictions, just for fun!
# generate the predictions
test_generator = test_datagen.flow(X_test,
get_onehot(y_test),
batch_size=1)
probas = model_WRN.predict_generator(test_generator,
steps=len(y_test))
preds_WRN = np.argmax(probas, axis=1)
# view correct predictions
view_predictions(model_WRN, preds_WRN, condition='Correct')
# view incorrect predictions
view_predictions(model_WRN, preds_WRN, condition='Incorrect')
Finally, to wrap things up, let's see how our model performs with a different image augmentation scheme. First we'll add some rotations, shears, and zooms to the mix:
# reset the tensorflow graph
K.backend.clear_session()
# construct model
model_WRN = build_WRN(40, 2, drop_rate=0.3)
# define the optimizer initial settings
sgd = K.optimizers.SGD(
lr=0.1,
decay=0.0005,
momentum=0.9,
nesterov=True)
# initiate training
this_name = 'WRN_AugScheme2_FullTrainSet'
(this_history, this_dir) = compile_and_train(
model=model_WRN,
MODEL_NAME=this_name,
X_train_data=np.vstack((X_train, X_val)),
y_train_data=np.hstack((y_train, y_val)),
X_val_data=X_test,
y_val_data=y_test,
train_generator=train_datagen_2,
test_generator=test_datagen,
BATCH_SIZE=128,
EPOCHS=400,
optimizer=sgd,
lr_schedule_fn=None)
# function for evaluating performance
this_score = evaluate_performance(
model_obj=model_WRN,
model_dir=this_dir,
X_test_data=X_test,
y_test_data=y_test,
test_generator=test_datagen,
model_hist=this_history,
model_lr_hist=None)
We see that the original image augmentation scheme was better: with this scheme we clearly get poorer model generalization, as indicated by the significantly wider gap between the training and validation curves.
Let's also examine what happens if we normalize our data based on mean and standard deviation, rather than via ZCA whitening:
# define a new image data generator
train_datagen_3 = ImageDataGenerator(
featurewise_center=True,
samplewise_center=False,
featurewise_std_normalization=True, # ADDED
samplewise_std_normalization=False,
zca_whitening=False, # MODIFIED
zca_epsilon=1e-06,
rotation_range=0.0,
width_shift_range=0.25,
height_shift_range=0.25,
brightness_range=None,
shear_range=0.0,
zoom_range=0.0,
channel_shift_range=0.0,
fill_mode='reflect',
cval=0.0,
horizontal_flip=True,
vertical_flip=False,
rescale=None, # REVERTED
preprocessing_function=None,
data_format='channels_last',
validation_split=0.0)
# modify the test set generator to use std/mean
# standardization
test_datagen_2 = ImageDataGenerator(
featurewise_center=True,
featurewise_std_normalization=True, # ADDED
zca_whitening=False, # MODIFIED
rescale=None, # REVERTED
data_format='channels_last')
train_datagen_3.fit(np.vstack((X_train, X_val)))
test_datagen_2.fit(np.vstack((X_train, X_val)))
# reset the tensorflow graph
K.backend.clear_session()
# construct model
model_WRN = build_WRN(40, 2, drop_rate=0.3)
# define the optimizer initial settings
sgd = K.optimizers.SGD(
lr=0.1,
decay=0.0005,
momentum=0.9,
nesterov=True)
# initiate training
this_name = 'WRN_AugScheme4(StdNorm)_FullTrainSet'
(this_history, this_dir) = compile_and_train(
model=model_WRN,
MODEL_NAME=this_name,
X_train_data=np.vstack((X_train, X_val)),
y_train_data=np.hstack((y_train, y_val)),
X_val_data=X_test,
y_val_data=y_test,
train_generator=train_datagen_3,
test_generator=test_datagen_2,
BATCH_SIZE=128,
EPOCHS=400,
optimizer=sgd,
lr_schedule_fn=None)
# function for evaluating performance
this_score = evaluate_performance(
model_obj=model_WRN,
model_dir=this_dir,
X_test_data=X_test,
y_test_data=y_test,
test_generator=test_datagen_2,
model_hist=this_history,
model_lr_hist=None)
In this case, the performance difference between ZCA and std/mean normalization was negligible.
In this tutorial we saw how Keras can be used to rapidly build complex machine learning models from the ground up. We first trained a simplified VGG model using the Sequential() model API, and then implemented a more sophisticated Wide Residual Network using the functional API.
As well-put by Keras creator Francois Chollet, "your final models will typically share little in common with the solutions you envisioned when first approaching the problem ...a-priori plans basically never survive confrontation with experimental reality." So get out there, build 'em, and iterate!