Keras provides the Callback interface for tracking results at every step of training, at both the batch level and the epoch level. Despite the name "callback", extending this functionality actually means subclassing keras.callbacks.Callback, which provides two attributes tied to the training run: params (a dict of training settings such as the number of epochs and verbosity) and model (a reference to the model being trained).
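The mechanism is easiest to see stripped of the framework. Below is a minimal sketch of the hook pattern that fit follows: a training loop pushes its parameters into each registered callback, then notifies them at every epoch boundary. The names here (MiniCallback, LossRecorder, run_training) are illustrative stand-ins, not part of the Keras API.

```python
class MiniCallback:
    """Base class: subclasses override only the hooks they care about."""
    def set_params(self, params):
        self.params = params          # e.g. {"epochs": 3}, set by the loop

    def on_epoch_end(self, epoch, logs=None):
        pass                          # default: do nothing


class LossRecorder(MiniCallback):
    """Collects the loss reported at the end of each epoch."""
    def __init__(self):
        self.losses = []

    def on_epoch_end(self, epoch, logs=None):
        self.losses.append((logs or {}).get("loss"))


def run_training(callbacks, epochs=3):
    """Stand-in for model.fit: fabricates a shrinking loss each epoch."""
    for cb in callbacks:
        cb.set_params({"epochs": epochs})
    for epoch in range(epochs):
        loss = 1.0 / (epoch + 1)      # fake, monotonically decreasing loss
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs={"loss": loss})


rec = LossRecorder()
run_training([rec])
print(rec.losses)                     # one recorded loss per epoch
```

Keras' real fit loop does the same wiring for us: it calls set_params and on_epoch_end on every entry of the callbacks list.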

Through this interface we can visualize, in real time, how the loss changes over each batch and each epoch of fit. Taking the example from "Neural Networks and Deep Learning", Chapter 3 ("Improving the way neural networks learn"), suppose we want to train the simplest possible neural network:

http://neuralnetworksanddeeplearning.com/images/tikz28.png

This single-neuron network has only two trainable parameters, a weight w and a bias b. Suppose the only training sample is (1, 0); we will compare how well the network learns under the MSE and cross-entropy cost functions. First, build the model:

from keras import Sequential, initializers, optimizers
from keras.layers import Activation, Dense

import numpy as np

def viz_keras_fit(w, b, runtime_plot=False,
                  loss="mean_squared_error",
                  act="sigmoid"):
    d = DrawCallback(runtime_plot=runtime_plot)
    
    # Initialize the weight and bias as constants
    w = initializers.Constant([w])
    b = initializers.Constant([b])

    x_train, y_train = np.array([1]), np.array([0])
    
    model = Sequential()
    model.add(Dense(1, 
        activation=act,
        input_shape=(1,),
        kernel_initializer=w,
        bias_initializer=b))

    # Learning Rate = 0.15
    sgd = optimizers.SGD(lr=0.15)
    model.compile(optimizer=sgd, loss=loss)

    model.fit(x = x_train,
        y = y_train,
        epochs=150,
        verbose=0,
        callbacks=[d]) # Callback List
    return d
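Before running fit, the effect of the starting point can be previewed numerically. At (w = 2, b = 2) the sigmoid neuron starts saturated, so the σ′(z) factor in the MSE gradient is tiny and the first updates are small. A quick check (the helper names here are illustrative, not from the post):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def mse_grad_w(w, b, x=1.0, y=0.0):
    """d/dw of (a - y)^2 for one sigmoid neuron, where a = sigmoid(w*x + b)."""
    a = sigmoid(w * x + b)
    return 2 * (a - y) * a * (1 - a) * x   # chain rule: sigma'(z) = a * (1 - a)

print(mse_grad_w(0.6, 0.9))  # healthy gradient
print(mse_grad_w(2.0, 2.0))  # several times smaller: a slow start
```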

Training the model with MSE as the loss function, starting from (w = 0.6, b = 0.9) and from (w = 2, b = 2) respectively:

https://s3-us-west-2.amazonaws.com/secure.notion-static.com/b6457326-9ae1-46ed-a247-d246ec2fa8a1/mse_sigmoid_0609.gif

https://s3-us-west-2.amazonaws.com/secure.notion-static.com/74e0e29d-e9fd-473d-ac76-2b8b055e382e/mse_sigmoid_22.gif
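The shape of the two curves above can be reproduced without Keras by running plain gradient descent on the single sigmoid neuron under the same settings (lr = 0.15, 150 epochs, sample (x, y) = (1, 0), loss (a − y)²). This is a sketch of the same dynamics, not the post's actual code path:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def train_mse(w, b, lr=0.15, epochs=150, x=1.0, y=0.0):
    """Return the per-epoch losses of SGD on one sigmoid neuron under MSE."""
    losses = []
    for _ in range(epochs):
        a = sigmoid(w * x + b)
        losses.append((a - y) ** 2)
        grad_z = 2 * (a - y) * a * (1 - a)  # dL/dz, with sigma'(z) = a(1 - a)
        w -= lr * grad_z * x
        b -= lr * grad_z
    return losses

fast = train_mse(0.6, 0.9)  # healthy start: loss drops quickly
slow = train_mse(2.0, 2.0)  # saturated start: long plateau before the drop
```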

Keeping the initial parameters at (2, 2) but switching the loss function to cross entropy:

https://s3-us-west-2.amazonaws.com/secure.notion-static.com/def7a144-9ccd-4ee3-aa3f-44201c2611ce/crossentropy_sigmoid_22.gif
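The qualitative difference between the two runs follows from the gradients, as derived in NNDL Chapter 3. For a single sigmoid neuron with z = wx + b and a = σ(z):

```latex
% Costs for a single training sample (x, y):
% C_{\mathrm{MSE}} = \tfrac{1}{2}(a - y)^2, \qquad
% C_{\mathrm{CE}}  = -\,\bigl[\, y \ln a + (1 - y)\ln(1 - a) \,\bigr]
\frac{\partial C_{\mathrm{MSE}}}{\partial w} = (a - y)\,\sigma'(z)\,x,
\qquad
\frac{\partial C_{\mathrm{CE}}}{\partial w} = (a - y)\,x
```

Since σ′(z) = a(1 − a) vanishes as the neuron saturates (a → 0 or 1), the MSE gradient stalls precisely when the output starts out badly wrong, whereas the cross-entropy gradient scales directly with the error a − y, which is why the (2, 2) run above escapes the plateau immediately.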

Although this achieves real-time visualization, drawing a frame can take longer than an epoch, so it is usually better to record the loss at each step first and plot afterwards:

https://s3-us-west-2.amazonaws.com/secure.notion-static.com/aa3c878f-98b9-4199-839f-fec7a407dfd3/runtime_plot_false.png

Watching the model learn in real time can help with early choices of loss function, activation function, model architecture, and hyperparameters. Here is the implementation of DrawCallback:

import pylab as pl
from IPython import display
from keras.callbacks import Callback

class DrawCallback(Callback):
    def __init__(self, runtime_plot=True):
        super().__init__()
        self.init_loss = None             # loss of the first epoch, fixes the y-axis scale
        self.runtime_plot = runtime_plot  # redraw after every epoch if True
        self.xdata = []                   # epoch indices
        self.ydata = []                   # loss per epoch

    def _plot(self, epoch=None):
        epochs = self.params.get("epochs")  # self.params is populated by fit
        pl.ylim(0, int(self.init_loss * 2))
        pl.xlim(0, epochs)
        pl.plot(self.xdata, self.ydata)
        pl.xlabel('Epoch {}/{}'.format(epoch or epochs, epochs))
        pl.ylabel('Loss {:.4f}'.format(self.ydata[-1]))

    def _runtime_plot(self, epoch):
        # Redraw in place: clear the notebook cell output, display the
        # current figure, then reset it for the next frame
        self._plot(epoch)
        display.clear_output(wait=True)
        display.display(pl.gcf())
        pl.gcf().clear()

    def plot(self):
        self._plot()
        pl.show()

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        loss = logs.get("loss")
        if self.init_loss is None:
            self.init_loss = loss
        self.xdata.append(epoch)
        self.ydata.append(loss)
        if self.runtime_plot:
            self._runtime_plot(epoch)