Poor Man's Deep Learning: MLP

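First, a word about the title. "Poor man's" means that our program runs on a CPU, and that it is a C++ program we wrote ourselves rather than a deep learning framework such as PyTorch running on a GPU. The MLP (multilayer perceptron), a stack of fully connected layers, is the simplest artificial neural network to implement, so that is where we start.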

Computation graph

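A computation graph is the dependency graph of a tensor computation. For a feedforward neural network this graph is a DAG. We propose (read: copy) the function we want to fit and a model to fit it with; a loss function then quantifies how well the model fits, and we optimize the value of that loss.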

Task

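We start with the classic MNIST. The input is a 28×28 grayscale image flattened into a 784-dimensional vector, and the output is a digit. We use one-hot encoding to turn that digit into a 10-dimensional vector, and cross entropy as the loss function. These functions are defined as follows: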

$$L=-\sum_i y_i\log y'_i$$
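As a rough illustration of the encoding and the loss above, here is a minimal sketch. The function names `one_hot` and `cross_entropy` are made up for this example, and it assumes that the prediction $y'$ is already a vector of strictly positive probabilities (for instance the output of a softmax, which the text does not spell out).

```cpp
#include <array>
#include <cmath>
#include <cstddef>

// One-hot encode a digit label (0..9) as a 10-dimensional vector.
std::array<float, 10> one_hot(std::size_t digit) {
    std::array<float, 10> y{};  // value-initialized to all zeros
    y.at(digit) = 1.0f;         // throws std::out_of_range if digit > 9
    return y;
}

// Cross entropy L = -sum_i y_i * log(y'_i).
// Assumption: every predicted[i] that is paired with a nonzero target[i]
// is a strictly positive probability.
float cross_entropy(const std::array<float, 10>& target,
                    const std::array<float, 10>& predicted) {
    float loss = 0.0f;
    for (std::size_t i = 0; i < target.size(); ++i) {
        if (target[i] != 0.0f) {  // skip zero terms, which also avoids 0 * log(0)
            loss -= target[i] * std::log(predicted[i]);
        }
    }
    return loss;
}
```

With a one-hot target only one term of the sum survives, so the loss reduces to the negative log of the probability assigned to the correct digit.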

where $\alpha_i$ is the activation function of layer $i$; here we use ReLU, defined as $R(x)=\max(x,0)$. $L_i$ is the transfer function of layer $i$: left-multiplication by a matrix of suitable size, followed by the addition of a bias of suitable size.
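The two building blocks just described can be sketched as follows. This is only an illustration: the names `relu` and `affine` and the row-major weight layout are assumptions of this example, not taken from the actual code.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// ReLU: R(x) = max(x, 0), applied elementwise.
std::vector<float> relu(std::vector<float> x) {
    for (float& v : x) v = std::max(v, 0.0f);
    return x;
}

// Affine transfer function: returns W * x + b.
// Assumption: W is stored row-major with b.size() rows and x.size() columns.
std::vector<float> affine(const std::vector<float>& W,
                          const std::vector<float>& b,
                          const std::vector<float>& x) {
    std::vector<float> out(b);  // start from the bias
    for (std::size_t r = 0; r < b.size(); ++r) {
        for (std::size_t c = 0; c < x.size(); ++c) {
            out[r] += W[r * x.size() + c] * x[c];
        }
    }
    return out;
}
```

One layer of the network is then `relu(affine(W, b, x))`, and stacking such calls gives the forward pass.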

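The MNIST dataset website records the results that various approaches achieve after training; here we implement a three-layer neural network with 500+150 hidden units (HU).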
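To get a feel for the size of this network, the sketch below counts its trainable parameters. It assumes the layer widths are 784 (the flattened image), 500, 150, and 10 (one output per digit); the names `kLayerSizes` and `parameter_count` are made up for this example.

```cpp
#include <array>
#include <cstddef>

// Assumed layer widths: 28*28 inputs, two hidden layers, 10 output classes.
const std::array<std::size_t, 4> kLayerSizes{{784, 500, 150, 10}};

// Total number of trainable parameters (weights plus biases) over all layers.
std::size_t parameter_count() {
    std::size_t total = 0;
    for (std::size_t i = 0; i + 1 < kLayerSizes.size(); ++i) {
        total += kLayerSizes[i] * kLayerSizes[i + 1];  // weight matrix
        total += kLayerSizes[i + 1];                   // bias vector
    }
    return total;
}
```

Under these assumptions the count is 392500 + 75150 + 1510 = 469160 parameters, small enough to train comfortably on a CPU.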

Optimization

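There are many numerical methods for optimizing the value of a function, but the structure of the computation graph makes it easy to obtain the partial derivative of the output with respect to every parameter. The simplest idea is to move the parameter vector a small step in the direction opposite to the gradient each time; the gradient can also be averaged over several training samples. This is the SGD algorithm. I implemented the Adam algorithm, which, at least for this task, converges much faster than SGD.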
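The two update rules mentioned above can be sketched as follows. This is not the code from the repository: the names `sgd_step` and `Adam` are made up here, and the default hyperparameters are the commonly used ones from the Adam paper rather than values given in the text.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Plain SGD step: theta <- theta - lr * grad.
void sgd_step(std::vector<float>& theta, const std::vector<float>& grad, float lr) {
    for (std::size_t i = 0; i < theta.size(); ++i) theta[i] -= lr * grad[i];
}

// Adam optimizer state and update step.
struct Adam {
    float lr = 1e-3f, beta1 = 0.9f, beta2 = 0.999f, eps = 1e-8f;  // assumed defaults
    std::vector<float> m, v;  // running first and second moment estimates
    long t = 0;               // number of steps taken so far

    void step(std::vector<float>& theta, const std::vector<float>& grad) {
        if (m.empty()) {  // lazily size the state on the first call
            m.assign(theta.size(), 0.0f);
            v.assign(theta.size(), 0.0f);
        }
        ++t;
        // Bias-correction factors for the zero-initialized moments.
        const float c1 = 1.0f - std::pow(beta1, static_cast<float>(t));
        const float c2 = 1.0f - std::pow(beta2, static_cast<float>(t));
        for (std::size_t i = 0; i < theta.size(); ++i) {
            m[i] = beta1 * m[i] + (1.0f - beta1) * grad[i];
            v[i] = beta2 * v[i] + (1.0f - beta2) * grad[i] * grad[i];
            const float m_hat = m[i] / c1;
            const float v_hat = v[i] / c2;
            theta[i] -= lr * m_hat / (std::sqrt(v_hat) + eps);
        }
    }
};
```

In both cases `grad` would be the gradient averaged over a mini-batch. Adam rescales each coordinate by a running estimate of its gradient magnitude, so the first step has size close to `lr` in every coordinate regardless of the raw gradient scale.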

Implementation

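I defined a tensor_t struct that stores a tensor's shape and data, and then defined a number of nodes; the formulas above show roughly which node types are needed.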

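One can implement a base class with out, acc, calculate(), and differentiate(), and have the node types inherit from it. After that, the for loops in the code can be rewritten as std::transform; using an execution_policy then gives partial parallelization. Of course, full parallelization still requires replicating the computation graph across multiple cores, computing the gradients of different training samples on each, and then aggregating them.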
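A base class along these lines might look like the sketch below. The member names out, acc, calculate(), and differentiate() come from the text; everything else is an assumption of this example: the exact fields of `tensor_t`, the names `node_t`, `input_node`, and `relu_node`, and the reading of acc as the accumulated gradient of the loss with respect to out.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical minimal tensor: a shape plus flat data.
struct tensor_t {
    std::vector<std::size_t> shape;
    std::vector<float> data;
};

// Base class for graph nodes.
struct node_t {
    tensor_t out;  // forward result
    tensor_t acc;  // assumed: accumulated gradient of the loss w.r.t. out
    virtual ~node_t() = default;
    virtual void calculate() = 0;      // forward: compute out from the inputs
    virtual void differentiate() = 0;  // backward: add into the inputs' acc
};

// Leaf node whose out is set directly by the caller.
struct input_node : node_t {
    void calculate() override {}
    void differentiate() override {}
};

// Elementwise ReLU node.
struct relu_node : node_t {
    node_t* in;
    explicit relu_node(node_t* input) : in(input) {}

    void calculate() override {
        out.shape = in->out.shape;
        out.data.resize(in->out.data.size());
        std::transform(in->out.data.begin(), in->out.data.end(), out.data.begin(),
                       [](float x) { return std::max(x, 0.0f); });
    }

    void differentiate() override {
        // resize keeps existing values and zero-fills any new elements.
        in->acc.data.resize(in->out.data.size());
        // The derivative of ReLU is 1 where the input is positive, else 0.
        for (std::size_t i = 0; i < out.data.size(); ++i) {
            if (in->out.data[i] > 0.0f) in->acc.data[i] += acc.data[i];
        }
    }
};
```

Since C++17, std::transform also has an overload that takes an execution policy such as std::execution::par as its first argument (from the <execution> header), which is one way to get the partial parallelization described above. It is left out of the sketch because some standard library implementations need an extra backend library to link it.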

The example code is in the repository; it contains a lot of duplicated code.