20行代码解决银行手写数字识别

20行代码解决银行手写数字识别

上述仅仅使用20行代码[1],就完成了手写数字的训练。而且准确率达到97.7%。不管手写数字是个什么鬼,反正完成了第一个神经网络的训练,而且结果还不错。

开心之后,我们静下心来想想,我们到底做了什么,达到这样的效果。那么,这些代码背后的运作机制是什么。

1. 前序

麻雀虽小,五脏俱全。这20行代码,简单表明了神经网络的训练过程:

  1. 加载数据
  2. 模型结构
  3. 模型训练前的参数配置
  4. 模型训练和评估

2. 加载数据

代码中加载了tensorflow中自带的MNist。主要有两个数据集:训练集和测试集。其中训练集有6000个样本,测试集有1000个样本。详细的数据展示可以查看扩展内容数据展示

3. 模型结构

以下几行代码就构成了整个模型,而且准确率还高达97.7%。

上述代码是由4层构成的,分别是平铺层Dense层Dropout层Dense层。并用Sequential将层与层之间关联起来。

其中平铺层(Flatten)就是将(28×28)的矩阵平铺成(784,)的向量。Dropout层是以一定比例随机停止神经元的连接。

前向传播---层层解读

层1:平铺层

该行代码就是将(28,28)的图像平铺成一个(784,)的向量,为后续Dense层的计算作准备。

层2:Dense层

$$ \begin{align}X_1 &= \{x_0, x_1, x_2, …, x_{784}\}, \\X&: 神经元的输入;\\W_1 &= \{w_0, w_1, w_2, …, w_{784}\} = matrix_{(784, 128)},\\W_1&: 偏置和权重;\\relu(z)&: 激活函数,表示对和的非线性变换,其本质就是对输入信息的选择。\end{align} $$
其中, 激活函数ReLU是目前常用而且简单的激活函数,搭建网络时推荐优先使用它。 其公式为: $$ ReLU(x) =\left\{\begin{aligned}x & & {x >=0}\\0 & & {x <0}​\end{aligned}\right. $$

其函数图像为:

激活函数是在神经网络中起到非常重要的作用,常用的有sigmoid函数、Tanh函数、Leaky-ReLU函数和ELU等。详细可以查看常用的激活函数

层3:Dropout层​

如下图[2]所示,左图是一般神经网络,右图是Dropout效果。其作用就是以一定概率随机的停止神经元工作,进而达到减少过拟合,提升泛化能力的作用。它是正则化的一种方式。

层4:Dense层

$$ \begin{align}H_2 &= \{h_0, h_1, h_2, …, h_{128}\}, \\X&: 神经元的输入;\\W_2 &= \{w_0, w_1, w_2, …, w_{10}\} = maxtrix_{(128, 10)},\\W_2&: 偏置和权重;\\​\sigma(z)&: 激活函数,表示对输出的非线性变换,其本质归一化。\end{align} $$

4. 损失函数(loss function)

损失函数本质上就是衡量预测与真实不一致的程度,而由于衡量的方式不同,产生了不同的损失函数。基础的有交叉熵损失函数(cross entropy loss function)和均方差损失函数(squre error loss funcion)。本代码中使用的是交叉熵损失函数。

对应代码如下:

交叉熵损失函数适于KL散度密切相关。感兴趣的同学可以扩展阅读,了解公式起源交叉熵与均方差

5. 优化方法

优化算法是在反向传播过程中,不断优化权重和偏置,进而最小化损失函数的方法。优化算法也各有不同。我们选用的是`adam`。它值一种`自适应`的优化算法。想进一步探索各个优化算法异同的同学可以参考延伸阅读Adam优化算法。

6.训练与评估

代码对样本训练了5次,之后利用测试集进行对每次训练后的模型进行评估。损失越小,精确度越高,模型就越好。

训练

模型训练的过程,其实就是不断更新权重和偏置的过程,更新后的权重和偏置使得损失函数达到最小。其训练过程如下:

  1. 利用默认方式”glorot_uniform”初始化权重;
  2. 循环以下两个步骤,直到收敛:
    • 计算梯度:\(\frac{\partial J(W)}{\partial W}\)
    • 更新梯度:  \( W \leftarrow W – \eta \frac{\partial J(W)}{\partial W} \)
  3. 返回权重

每次更新权重的过程,就是反向传播的过程。具体的推导公式详见反向传播

评估

评估的方式是计算测试集的损失和准确率。

准确度计算公式: \[ accuracy = \frac{{\text{预测正确的样本数}}}{总的样本数} \]

7.结果

训练结果如下:

Epoch 1/5
60000/60000 [==============================] - 3s 53us/sample - loss: 0.2977 - accuracy: 0.9133 - val_loss: 0.1356 - val_accuracy: 0.9594
Epoch 2/5
60000/60000 [==============================] - 3s 44us/sample - loss: 0.1428 - accuracy: 0.9582 - val_loss: 0.1012 - val_accuracy: 0.9690
Epoch 3/5
60000/60000 [==============================] - 3s 45us/sample - loss: 0.1043 - accuracy: 0.9681 - val_loss: 0.0843 - val_accuracy: 0.9733
Epoch 4/5
60000/60000 [==============================] - 3s 43us/sample - loss: 0.0862 - accuracy: 0.9736 - val_loss: 0.0764 - val_accuracy: 0.9764
Epoch 5/5
60000/60000 [==============================] - 3s 42us/sample - loss: 0.0752 - accuracy: 0.9768 - val_loss: 0.0759 - val_accuracy: 0.9768
10000/1 - 0s - loss: 0.0383 - accuracy: 0.9768