
A TensorFlow callback is an object whose methods are called at specific stages of the training process, for instance at the end of each training epoch or before evaluating on the validation set. Callbacks are a convenient way to monitor and control training. In this section, we introduce some commonly used callbacks: EarlyStopping, LearningRateScheduler, ModelCheckpoint, and TensorBoard.
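To make the idea of stage-specific hooks concrete, here is a toy, framework-free sketch of a training loop that calls callback methods at fixed points. The hook names mirror a few methods of the Keras Callback class (on_train_begin, on_epoch_begin, on_epoch_end, on_train_end), but the toy_fit loop, the EventRecorder class, and the placeholder loss values are illustrative assumptions, not Keras's implementation.

```python
class Callback:
    """Toy base class: every hook is a no-op unless a subclass overrides it.
    Hook names mirror Keras's Callback, but this is NOT the Keras class."""

    def on_train_begin(self, logs=None):
        pass

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass


class EventRecorder(Callback):
    """Records the order in which the training loop calls its hooks."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin {epoch}")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end {epoch} loss={logs['loss']:.2f}")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


def toy_fit(num_epochs, callbacks):
    """Hypothetical training loop: no real training, only the hook calls."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch in range(num_epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        logs = {"loss": 1.0 / (epoch + 1)}  # placeholder value, not a real loss
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)
    for cb in callbacks:
        cb.on_train_end()


recorder = EventRecorder()
toy_fit(num_epochs=2, callbacks=[recorder])
print(recorder.events)
# ['train_begin', 'epoch_begin 0', 'epoch_end 0 loss=1.00', 'epoch_begin 1', 'epoch_end 1 loss=0.50', 'train_end']
```

In Keras itself the pattern is the same: subclass tf.keras.callbacks.Callback, override only the hooks you need, and pass an instance in the callbacks list of model.fit.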

Early Stopping

EarlyStopping(monitor=metrics, min_delta=delta, patience=patience) API

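EarlyStopping terminates training when the monitored metric (monitor) has not improved by more than min_delta for patience consecutive epochs. With the settings below, training stops once the validation accuracy fails to rise by more than 0.05 over the best value so far for 2 epochs in a row.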

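es_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_sparse_categorical_accuracy',  # metric name as it appears in the training logs
    mode='max',      # higher accuracy is better
    min_delta=0.05,  # smaller increases do not count as an improvement
    patience=2       # stop after 2 consecutive epochs without improvement
)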
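The stopping rule can be traced by hand on a list of per-epoch metric values. The sketch below is a simplified re-implementation of the min_delta/patience logic, not Keras's actual code; the function name stopping_epoch and the accuracy values are hypothetical.

```python
def stopping_epoch(history, min_delta, patience, mode="max"):
    """Simplified early-stopping rule (not Keras's actual code).

    history: monitored metric value at the end of each epoch.
    Returns how many epochs run before training stops.
    """
    best = None
    wait = 0  # consecutive epochs without a large-enough improvement
    for epoch, value in enumerate(history):
        if best is None:
            improved = True  # the first epoch always sets the baseline
        elif mode == "max":
            improved = value - best > min_delta  # e.g. accuracy must rise by more than min_delta
        else:
            improved = best - value > min_delta  # e.g. loss must fall by more than min_delta
        if improved:
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch + 1  # training stops after this epoch
    return len(history)  # the rule never fired: all epochs run


# Hypothetical validation accuracies, using the settings of es_callback.
print(stopping_epoch([0.90, 0.92, 0.93, 0.99], min_delta=0.05, patience=2))  # 3
```

Training stops after the third epoch even though the (hypothetical) fourth epoch would have been much better, which is why a large min_delta combined with a small patience should be chosen with care. Keras's EarlyStopping also accepts restore_best_weights=True to roll the model back to the weights of its best epoch.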

Learning Rate Scheduler

LearningRateScheduler(scheduler_func) API

LearningRateScheduler modifies the learning rate based on the epoch index. At the start of each epoch, it calls scheduler_func with two parameters, the epoch index (starting from 0) and the optimizer's current learning rate, and sets the learning rate to the value the function returns.

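import math

def scheduler(epoch, lr):
    # Keep the learning rate unchanged for the first 10 epochs (epochs 0-9),
    # then multiply it by exp(-0.1) (about 0.905) at the start of every later epoch.
    if epoch < 10:
        return lr
    else:
        # math.exp returns a plain Python float, which Keras accepts as a schedule output.
        return lr * math.exp(-0.1)

ls_callback = tf.keras.callbacks.LearningRateScheduler(scheduler)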
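Because the scheduler receives the current learning rate and returns a new one, the decay compounds from epoch to epoch. The sketch below replays that feedback loop without TensorFlow to show the resulting learning rate per epoch. It assumes Adam's default initial learning rate of 0.001; simulate_schedule is a hypothetical helper, and scheduler is redefined so the block runs on its own.

```python
import math


def scheduler(epoch, lr):
    # Same rule as above: constant for epochs 0-9, then decay by exp(-0.1) per epoch.
    if epoch < 10:
        return lr
    return lr * math.exp(-0.1)


def simulate_schedule(schedule, initial_lr, num_epochs):
    """Learning rate used in each epoch when `schedule` is applied once per
    epoch and fed its own previous output (hypothetical helper)."""
    lrs = []
    lr = initial_lr
    for epoch in range(num_epochs):
        lr = schedule(epoch, lr)
        lrs.append(lr)
    return lrs


lrs = simulate_schedule(scheduler, initial_lr=1e-3, num_epochs=15)
for epoch, lr in enumerate(lrs):
    print(f"epoch {epoch:2d}: lr = {lr:.6f}")
```

For every epoch e of 10 or more, the rate works out to 0.001 * exp(-0.1 * (e - 9)), so epoch 14 uses about 0.001 * exp(-0.5), roughly 6.1e-4.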

Model Checkpoint

ModelCheckpoint(filepath, monitor=metrics, save_best_only=False, save_weights_only=False) API

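ModelCheckpoint saves the model (or only its weights, when save_weights_only=True) to filepath during training, by default at the end of every epoch. With save_best_only=True, the file is written only when the monitored metric improves on the best value seen so far.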

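mc_callback = tf.keras.callbacks.ModelCheckpoint(
    'logs/best_model.h5',                       # where the model is written
    monitor='val_sparse_categorical_accuracy',  # metric that defines "best"
    save_best_only=True                         # overwrite the file only when the metric improves
)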
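The filepath may also contain named formatting placeholders (for example {epoch:02d} or a metric name such as {val_sparse_categorical_accuracy:.3f}), which Keras fills with the epoch number and the values logged for that epoch, so each saved checkpoint gets its own file. The pure-Python sketch below imitates which files would be written with save_best_only, assuming higher metric values are better and, as in recent Keras versions, a 1-based epoch number in file names. The function checkpoint_paths and the metric values are hypothetical.

```python
def checkpoint_paths(filepath, epoch_logs, monitor, save_best_only=True):
    """Simplified sketch of which files a checkpoint callback writes.

    Assumes higher values of `monitor` are better (like accuracy).
    epoch_logs: one dict of metrics per epoch (hypothetical values).
    """
    saved = []
    best = float("-inf")
    for epoch, logs in enumerate(epoch_logs):
        current = logs[monitor]
        if save_best_only and not current > best:
            continue  # no improvement: the previous best file is kept
        best = max(best, current)
        # Fill the template with the 1-based epoch number and this epoch's metrics.
        saved.append(filepath.format(epoch=epoch + 1, **logs))
    return saved


template = "logs/model.{epoch:02d}-{val_sparse_categorical_accuracy:.3f}.h5"
logs = [
    {"val_sparse_categorical_accuracy": 0.95},
    {"val_sparse_categorical_accuracy": 0.97},
    {"val_sparse_categorical_accuracy": 0.96},
]
print(checkpoint_paths(template, logs, "val_sparse_categorical_accuracy"))
# ['logs/model.01-0.950.h5', 'logs/model.02-0.970.h5']
```

With a fixed name such as 'logs/best_model.h5', every improvement overwrites the same file, so the file left at the end holds the best model seen during training.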

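TensorBoard

TensorBoard(log_dir) API

The TensorBoard callback writes training logs (such as the loss and metric values of each epoch) to log_dir so that the training process can be visualized in TensorBoard.

tb_callback = tf.keras.callbacks.TensorBoard(log_dir='logs')

To view the logs, launch TensorBoard from a shell and open the URL it prints (by default http://localhost:6006):

$ tensorboard --logdir=logs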
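If every run writes to the same 'logs' directory, curves from different runs are hard to tell apart. A common pattern, similar to the one in the TensorFlow TensorBoard tutorials, is to give each run its own timestamped subdirectory; since tensorboard --logdir=logs also picks up runs stored in subdirectories, each one then appears as a separate run. The helper run_log_dir and the fit/ subfolder below are hypothetical naming choices.

```python
import os
from datetime import datetime


def run_log_dir(root="logs", run_name=None):
    """Return a per-run log directory such as logs/fit/20240101-120000.

    A timestamp (YYYYMMDD-HHMMSS) is used as the run name unless one is given.
    """
    if run_name is None:
        run_name = datetime.now().strftime("%Y%m%d-%H%M%S")
    return os.path.join(root, "fit", run_name)


print(run_log_dir())                    # e.g. logs/fit/<timestamp>
print(run_log_dir(run_name="mlp_128"))  # logs/fit/mlp_128 on POSIX systems
```

The result can then be passed to the callback, for example tb_callback = tf.keras.callbacks.TensorBoard(log_dir=run_log_dir()).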

Example

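import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from tensorflow.keras import losses, metrics, optimizers

class MLP(tf.keras.Model):
    """A multilayer perceptron with one hidden layer for 28x28 grayscale images."""

    def __init__(self, encoding_dim, activation='relu'):
        super().__init__()
        self.encoding_dim = encoding_dim
        self.activation = activation
        self.network = tf.keras.Sequential([
            keras.layers.Flatten(input_shape=(28, 28, 1)),                      # image -> 784-vector
            keras.layers.Dense(self.encoding_dim, activation=self.activation),  # hidden layer
            keras.layers.Dense(10)                                              # one logit per digit
        ])

    def call(self, x):
        return self.network(x)

# Download MNIST (first run only) and load each split as one batch of tensors.
mnist = tfds.image.MNIST()
mnist.download_and_prepare()
mnist_data = mnist.as_dataset(batch_size=-1, shuffle_files=True)
mnist_train, mnist_test = mnist_data["train"], mnist_data["test"]

# Scale pixel values from [0, 255] to [0, 1].
train_images = tf.cast(mnist_train['image'], tf.float32) / 255.0
train_labels = mnist_train['label']

model = MLP(encoding_dim=128)
model.compile(optimizer=optimizers.Adam(),
              loss=losses.SparseCategoricalCrossentropy(from_logits=True),  # the last layer outputs logits
              metrics=[metrics.SparseCategoricalAccuracy()])

# The last 25% of the training samples are held out for validation;
# es_callback monitors val_sparse_categorical_accuracy on that split.
history = model.fit(
    train_images,
    train_labels,
    epochs=10,
    validation_split=0.25,
    callbacks=[es_callback]
)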