TensorFlow Callback objects execute different functions at different stages of the training process, for instance after each training epoch or before prediction on the validation set. They provide a range of utilities that help us keep track of the learning process. In this section, we introduce some commonly used callbacks: EarlyStopping, LearningRateScheduler, ModelCheckpoint, and TensorBoard.
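To make the idea of "stages" concrete, here is a minimal pure-Python sketch of how a training loop might dispatch to callback hooks. It is a toy stand-in, not Keras itself: RecordingCallback and run_toy_training are hypothetical names introduced for illustration, although the hook names mirror those of tf.keras.callbacks.Callback (on_train_begin, on_epoch_begin, on_epoch_end, on_train_end).

```python
class RecordingCallback:
    """Toy callback that records the order in which its hooks are called."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin:{epoch}")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end:{epoch}")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


def run_toy_training(callbacks, epochs):
    """Hypothetical stand-in for model.fit: no learning, only hook dispatch."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        # A real loop would run the training and validation steps here.
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs={})
    for cb in callbacks:
        cb.on_train_end()


recorder = RecordingCallback()
run_toy_training([recorder], epochs=2)
print(recorder.events)
```

Each built-in callback in this section plugs its own logic into hooks like these. A real custom callback would subclass tf.keras.callbacks.Callback and override only the hooks it needs.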
EarlyStopping(monitor=metrics, min_delta=delta, patience=patience) API
EarlyStopping terminates training when the monitored metric has not improved by at least min_delta for patience consecutive epochs.
es_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_sparse_categorical_accuracy',
    min_delta=0.05,
    patience=2
)
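The interaction between min_delta and patience can be easier to see in plain Python. The sketch below is a simplified approximation of the stopping rule, assuming a metric where larger is better (such as accuracy) and ignoring options like baseline or restore_best_weights. The function name early_stop_epoch and the metric values are made up for illustration and are not results from the model in this section.

```python
def early_stop_epoch(values, min_delta, patience):
    """Return the 0-based epoch at which training would stop, or None.

    Simplified sketch for a metric where larger is better (e.g. accuracy).
    """
    best = float("-inf")
    wait = 0  # epochs since the last sufficient improvement
    for epoch, current in enumerate(values):
        if current - min_delta > best:
            # Improved by more than min_delta: remember it and reset the counter.
            best = current
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Illustrative validation accuracies (placeholder numbers).
plateau = [0.80, 0.86, 0.88, 0.89, 0.90]
steady = [0.5, 0.6, 0.7, 0.8]

print(early_stop_epoch(plateau, min_delta=0.05, patience=2))
print(early_stop_epoch(steady, min_delta=0.05, patience=2))
```

In the first sequence the accuracy keeps rising, but by less than 0.05 per epoch relative to the best value so far, so the run is still stopped. A large min_delta such as 0.05 therefore makes early stopping quite aggressive.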
LearningRateScheduler(scheduler_func) API
Modify the learning rate based on the epoch index. The scheduler_func takes two parameters as input and returns the new learning rate: the first parameter is the epoch index (starting from 0), and the second is the current learning rate.
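def scheduler(epoch, lr):
    # Keep the initial rate for the first 10 epochs, then decay it exponentially.
    if epoch < 10:
        return lr
    else:
        # Return a plain Python float rather than a tensor.
        return float(lr * tf.math.exp(-0.1))

ls_callback = tf.keras.callbacks.LearningRateScheduler(scheduler)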
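To see what this schedule does over time, the following sketch traces the learning rate across epochs without running any training. It re-implements the same rule with math.exp so that it runs without TensorFlow. The starting rate of 0.001 is an assumption (it matches the Keras default for Adam), and the loop only imitates how the callback feeds the current rate back into the function at each epoch.

```python
import math


def scheduler(epoch, lr):
    # Same rule as the scheduler above, written with math.exp.
    if epoch < 10:
        return lr
    return lr * math.exp(-0.1)


lr = 0.001  # assumed starting learning rate
trace = []
for epoch in range(13):
    lr = scheduler(epoch, lr)  # the returned rate becomes the next input
    trace.append(lr)

for epoch, value in enumerate(trace):
    print(epoch, value)
```

Because the output of one epoch is the input of the next, the decay compounds: after epoch 10 the rate is multiplied by exp(-0.1) once per epoch.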
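ModelCheckpoint(filepath, monitor=metrics, save_best_only=False, save_weights_only=False) API
Save the model (or only its weights, when save_weights_only=True) to the given filepath. With save_best_only=True, the file is only overwritten when the monitored metric improves.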
mc_callback = tf.keras.callbacks.ModelCheckpoint(
    'logs/best_model.h5',
    monitor='val_sparse_categorical_accuracy',
    save_best_only=True
)
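The filepath may also contain named formatting fields such as {epoch}, or any key from the epoch logs, which Keras fills in before saving. The sketch below imitates that behaviour together with save_best_only in plain Python, assuming a metric where larger is better. The function checkpoint_plan, the short key val_acc, and the log values are hypothetical and only serve the illustration.

```python
def checkpoint_plan(filepath, history, monitor):
    """Return the file names a save_best_only checkpoint would write.

    Simplified sketch assuming a metric where larger is better.
    `history` is a list of per-epoch log dicts.
    """
    best = float("-inf")
    written = []
    for epoch, logs in enumerate(history):
        current = logs[monitor]
        if current > best:
            best = current
            # Epochs are numbered from 1 when the template is filled in.
            written.append(filepath.format(epoch=epoch + 1, **logs))
    return written


# Placeholder logs, not results from a real run.
logs_per_epoch = [{"val_acc": 0.81}, {"val_acc": 0.79}, {"val_acc": 0.85}]
files = checkpoint_plan(
    "logs/model_{epoch:02d}_{val_acc:.2f}.h5", logs_per_epoch, "val_acc"
)
print(files)
```

With a fixed name such as 'logs/best_model.h5', each improvement overwrites the previous file. With a templated name, every improvement is kept as a separate file.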
TensorBoard(log_dir) API
Write logs that TensorBoard uses to visualize the training process. Use the following command to launch TensorBoard.
$ tensorboard --logdir=logs
tb_callback = tf.keras.callbacks.TensorBoard(log_dir='logs')
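When training is run several times, a common pattern is to give each run its own subdirectory under logs so that TensorBoard shows the runs side by side. The sketch below builds such a timestamped path. The helper make_run_log_dir is a hypothetical name, and the timestamp format is just one reasonable choice.

```python
import datetime
import os


def make_run_log_dir(root="logs", now=None):
    """Build a per-run log directory such as logs/20240131-093000."""
    now = now or datetime.datetime.now()
    return os.path.join(root, now.strftime("%Y%m%d-%H%M%S"))


log_dir = make_run_log_dir()
print(log_dir)
# The result would then be passed to the callback, for example:
# tb_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
```

Running tensorboard --logdir=logs as before still works, since TensorBoard treats each subdirectory of the log directory as a separate run.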
class MLP(tf.keras.Model):
    def __init__(self, encoding_dim, activation='relu'):
        super().__init__()
        self.encoding_dim = encoding_dim
        self.activation = activation
        self.network = tf.keras.Sequential([
            keras.layers.Flatten(input_shape=(28, 28, 1)),
            keras.layers.Dense(self.encoding_dim, activation=self.activation),
            keras.layers.Dense(10)
        ])

    def call(self, x):
        return self.network(x)
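mnist = tfds.image.MNIST()
# Download and prepare the data before reading it (skipped if already prepared).
mnist.download_and_prepare()
# batch_size=-1 loads each split as a single batch of tensors.
mnist_data = mnist.as_dataset(batch_size=-1, shuffle_files=True)
mnist_train, mnist_test = mnist_data["train"], mnist_data["test"]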
model = MLP(encoding_dim=128)
model.compile(optimizer=optimizers.Adam(),
              loss=losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=[metrics.SparseCategoricalAccuracy()])
history = model.fit(
    mnist_train['image'],
    mnist_train['label'],
    epochs=10,
    validation_split=0.25,
    callbacks=[es_callback]
)
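The object returned by fit exposes history.history, a dict that maps each metric name to a list with one value per completed epoch. The sketch below shows one way to read it, assuming the key val_sparse_categorical_accuracy produced by the metric and validation split above. The helper best_epoch is hypothetical, and the dict uses placeholder numbers rather than results from this model.

```python
def best_epoch(history_dict, key="val_sparse_categorical_accuracy"):
    """Return (1-based epoch, value) of the highest value recorded for `key`."""
    values = history_dict[key]
    best_index = max(range(len(values)), key=values.__getitem__)
    return best_index + 1, values[best_index]


# Placeholder stand-in for history.history (not real training output).
example_history = {"val_sparse_categorical_accuracy": [0.90, 0.93, 0.92]}

epochs_run = len(example_history["val_sparse_categorical_accuracy"])
print(epochs_run, best_epoch(example_history))
```

If EarlyStopping ended the run early, the lists are shorter than the epochs argument passed to fit, so their length tells how many epochs actually ran.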