PyTorch 14:权值初始化

PyTorch 学习笔记

Posted by YEY on December 17, 2020

Lecture 14 权值初始化

在前几节课中,我们学习了如何搭建网络模型。在网络模型搭建好之后,有一个非常重要的步骤,就是对模型中的权值进行初始化:正确的权值初始化可以加快模型的收敛,而不适当的权值初始化可以会引发梯度消失或者爆炸,最终导致模型无法训练。本节课,我们将学习如何进行权值初始化。

1. 梯度消失与爆炸

这里,我们以上节课中提到的一个三层的全连接网络为例。我们来看一下第二个隐藏层中的权值 $W_2$ 的梯度是怎么求取的。

从公式角度来看,为了防止发生梯度消失或者爆炸,我们必须严格控制网络层输出值的尺度范围,即每个网络层的输出值不能太大或者太小。

代码示例

100 个线性层的简单叠加

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # 设置随机种子


class MLP(nn.Module):
    def __init__(self, neural_num, layers):
        super(MLP, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
        self.neural_num = neural_num

    def forward(self, x):
        for (i, linear) in enumerate(self.linears):
            x = linear(x)
        return x

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data)  # normal: mean=0, std=1


layer_nums = 100  # 网络层数
neural_nums = 256  # 每层的神经元个数
batch_size = 16  # 输入数据的 batch size

net = MLP(neural_nums, layer_nums)
net.initialize()

inputs = torch.randn((batch_size, neural_nums))  # normal: mean=0, std=1
output = net(inputs)
print(output)

输出结果:

1
2
3
4
5
6
7
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<MmBackward>) 

我们发现 output 中的每一个值都是 nan,这表明我们的数据值可能非常大或者非常小,已经超出了当前精度能够表示的范围。我们可以修改一下 forward 函数,来看一下什么时候我们的数据变为了 nan。这里,我们采用标准差作为指标来衡量数据的尺度范围。首先我们打印出每层的标准差,接着进行一个 if 判断,如果 x 的标准差变为 nan 了则停止前向传播。

1
2
3
4
5
6
7
8
9
10
def forward(self, x):
    for (i, linear) in enumerate(self.linears):
        x = linear(x)

        print("layer:{}, std:{}".format(i, x.std()))  # 打印每层的标准差
        if torch.isnan(x.std()):
            print("output is nan in {} layers".format(i))
            break  # 如果 x 的标准差变为 nan 则停止前向传播

        return x

输出结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
layer:0, std:15.959932327270508
layer:1, std:256.6237487792969
layer:2, std:4107.24560546875
layer:3, std:65576.8125
layer:4, std:1045011.875
layer:5, std:17110408.0
layer:6, std:275461408.0
layer:7, std:4402537984.0
layer:8, std:71323615232.0
layer:9, std:1148104736768.0
layer:10, std:17911758454784.0
layer:11, std:283574846619648.0
layer:12, std:4480599809064960.0
layer:13, std:7.196814275405414e+16
layer:14, std:1.1507761512626258e+18
layer:15, std:1.853110740188555e+19
layer:16, std:2.9677725826641455e+20
layer:17, std:4.780376223769898e+21
layer:18, std:7.613223480799065e+22
layer:19, std:1.2092652108825478e+24
layer:20, std:1.923257075956356e+25
layer:21, std:3.134467063655912e+26
layer:22, std:5.014437766285408e+27
layer:23, std:8.066615144249704e+28
layer:24, std:1.2392661553516338e+30
layer:25, std:1.9455688099759845e+31
layer:26, std:3.0238180658999113e+32
layer:27, std:4.950357571077011e+33
layer:28, std:8.150925520353362e+34
layer:29, std:1.322983152787379e+36
layer:30, std:2.0786820453988485e+37
layer:31, std:nan
output is nan in 31 layers
tensor([[        inf, -2.6817e+38,         inf,  ...,         inf,
                 inf,         inf],
        [       -inf,        -inf,  1.4387e+38,  ..., -1.3409e+38,
         -1.9659e+38,        -inf],
        [-1.5873e+37,         inf,        -inf,  ...,         inf,
                -inf,  1.1484e+38],
        ...,
        [ 2.7754e+38, -1.6783e+38, -1.5531e+38,  ...,         inf,
         -9.9440e+37, -2.5132e+38],
        [-7.7184e+37,        -inf,         inf,  ..., -2.6505e+38,
                 inf,         inf],
        [        inf,         inf,        -inf,  ...,        -inf,
                 inf,  1.7432e+38]], grad_fn=<MmBackward>)

可以看到,当进行到 31 层的时候,数据的标准差就已经变为 nan 了。我们看到,在第 31 层的时候,数据的值都非常大或者非常小,再往后传播,计算机当前的精度就已经没办法去表示这些特别大或者特别小的数据了。另外,可以看到每一层网络的标准差都是逐渐增大的,直到第 31 层,大约在 $10^{38} \sim 10^{39}$ 之间,而这已经超出了我们当前精度可以表示的数据范围。

下面我们通过方差的公式推导来观察为什么网络层输出的标准差会越来越大,最终超出可表示的范围。假设 $X$ 和 $Y$ 是两个相互独立的随机变量,我们知道:

\[\begin{aligned} \mathrm{E}(X*Y) &= \mathrm{E}(X) * \mathrm{E}(Y) \\[2ex] \mathrm{Var}(X) &= \mathrm{E}(X^2) - [\mathrm{E}(X)]^2 \\[2ex] \mathrm{Var}(X+Y) &= \mathrm{Var}(X) + \mathrm{Var}(Y) \end{aligned}\]

然后,我们有:

\[\mathrm{Var}(X*Y) = \mathrm{Var}(X) * \mathrm{Var}(Y) + \mathrm{Var}(X) * [\mathrm{E}(Y)]^2 + \mathrm{Var}(Y) * [\mathrm{E}(X)]^2\]

如果 $\mathrm{E}(X) =0,\;\mathrm{E}(Y) =0$,那么我们有:

\[\mathrm{Var}(X*Y) = \mathrm{Var}(X) * \mathrm{Var}(Y)\]

下面我们来计算网络层神经元的标准差:

\[H_{11} = \sum_{i=0}^{n}X_i * W_{1i}\]

由于 $X$ 和 $W$ 都是均值为 $0$,标准差为 $1$,我们有:

\[\begin{aligned} \mathrm{Var}(H_{11}) &= \sum_{i=0}^{n} \mathrm{Var}(X_{i}) * \mathrm{Var}(W_{1i}) \\ &= n * (1 * 1) \\ &= n \end{aligned}\]

所以,

\[\mathrm{Std}(H_{11}) = \sqrt{\mathrm{Var}(H_{11})} = \sqrt{n}\]

可以看到,第一个隐藏层中神经元的方差变为了 $n$,而输入 $X$ 的方差为 $1$。也就是说,经过第一个网络层 $H_1$ 的前向传播,数据的方差扩大了 $n$ 倍,标准差扩大了 $\sqrt{n}$ 倍。同理,如果继续传播到下一个隐藏层 $H_2$,通过公式推导,可知该层神经元的标准差为 $n$。这样不断传播下去,每经过一层,输出数据的尺度范围都将不断扩大 $\sqrt{n}$ 倍,最终将超出我们的精度可表示的范围,变为 nan

在代码中,我们设置的每层网络中神经元个数 $n=256$,所以 $\sqrt{n} = 16$。我们来看一下前面输出结果中的每个网络层输出的标准差是否符合这一规律:

  • 第 0 层数据标准差为 $15.959932327270508 \approx 16$
  • 第 1 层数据标准差为 $256.6237487792969 \approx 16^2=256$
  • 第 2 层数据标准差为 $4107.24560546875 \approx 16^3 = 4096$
  • 第 3 层数据标准差为 $65576.8125 \approx 16^4 = 65536$
  • ……

每经过一层,数据的标准差都会扩大 $16$ 倍,经过一层层传播后,数据的标准差将变得非常大,最终在第 31 层时超出了精度可表示的范围,即为 nan

下面我们将每层神经元个数修改为 $n=400$,所以 $\sqrt{n}=20$,观察结果是否会发生相应的变化:

1
2
3
layer_nums = 100  # 网络层数
neural_nums = 400  # 每层的神经元个数
batch_size = 16  # 输入数据的 batch size

输出结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
layer:0, std:20.191545486450195
layer:1, std:406.2967834472656
layer:2, std:8196.0322265625
layer:3, std:164936.546875
layer:4, std:3324399.75
layer:5, std:65078964.0
layer:6, std:1294259712.0
layer:7, std:25718734848.0
layer:8, std:509478502400.0
layer:9, std:10142528569344.0
layer:10, std:204187744862208.0
layer:11, std:4146330289045504.0
layer:12, std:8.175371463688192e+16
layer:13, std:1.6178185228915835e+18
layer:14, std:3.201268126493075e+19
layer:15, std:6.43244420071468e+20
layer:16, std:1.2768073112864894e+22
layer:17, std:2.5327442663597998e+23
layer:18, std:4.97064812888673e+24
layer:19, std:9.969679340542473e+25
layer:20, std:1.9616922876332235e+27
layer:21, std:3.926491184057203e+28
layer:22, std:7.928349353787082e+29
layer:23, std:1.5731294716685355e+31
layer:24, std:3.156214979388958e+32
layer:25, std:6.18353463606124e+33
layer:26, std:1.2453666891690611e+35
layer:27, std:2.467429285844339e+36
layer:28, std:4.977222187097705e+37
layer:29, std:nan
output is nan in 29 layers
tensor([[-inf, inf, inf,  ..., -inf, nan, nan],
        [nan, nan, inf,  ..., -inf, -inf, nan],
        [nan, -inf, nan,  ..., inf, nan, nan],
        ...,
        [nan, -inf, -inf,  ..., -inf, nan, nan],
        [inf, -inf, nan,  ..., inf, -inf, nan],
        [inf, nan, inf,  ..., inf, nan, inf]], grad_fn=<MmBackward>)

可以看到:

  • 第 0 层数据标准差为 $20.191545486450195 \approx 20$
  • 第 1 层数据标准差为 $406.2967834472656 \approx 20^2=400$
  • 第 2 层数据标准差为 $8196.0322265625 \approx 20^3 = 8000$
  • 第 3 层数据标准差为 $164936.546875 \approx 20^4 = 160000$
  • ……

每经过一层,数据的标准差大约会扩大 $20$ 倍,最终在第 29 层时超出了精度可表示的范围,变为 nan

从前面的公式中可以看到,每个网络层输出数据的标准差由三个因素决定:网络层的神经元个数 $n$、输入值 $X$ 的方差 $\mathrm{Var}(X)$,以及网络层权值 $W$ 的方差 $\mathrm{Var}(W)$。因此,如果我们希望让网络层输出数据的方差保持尺度不变,那么我们必须令其方差等于 $1$,即:

\[\mathrm{Var}(H_{1}) = n * \mathrm{Var}(X) * \mathrm{Var}(W) = 1\]

因此,

\[\mathrm{Var}(W) = \dfrac{1}{n} \quad \Longrightarrow \quad \mathrm{Std}(W) = \sqrt{\dfrac{1}{n}}\]

所以,当我们将网络层权值的标准差设为 $\sqrt{\frac{1}{n}}$ 时,输出数据的标准差将变为 $1$。

下面我们修改代码,使用一个均值为 $0$,标准差为 $\sqrt{\frac{1}{n}}$ 的分布来初始化权值矩阵 $W$,观察网络层输出数据的标准差会如何变化:

1
2
3
4
 def initialize(self):
    for m in self.modules():
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight.data, std=np.sqrt(1/self.neural_num))    # normal: mean=0, std=sqrt(1/n)

输出结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
layer:0, std:0.9974957704544067
layer:1, std:1.0024365186691284
layer:2, std:1.002745509147644
layer:3, std:1.0006227493286133
layer:4, std:0.9966009855270386
layer:5, std:1.019859790802002
layer:6, std:1.026173710823059
layer:7, std:1.0250457525253296
layer:8, std:1.0378952026367188
layer:9, std:1.0441951751708984
layer:10, std:1.0181655883789062
layer:11, std:1.0074602365493774
layer:12, std:0.9948930144309998
layer:13, std:0.9987586140632629
layer:14, std:0.9981392025947571
layer:15, std:1.0045733451843262
layer:16, std:1.0055204629898071
layer:17, std:1.0122840404510498
layer:18, std:1.0076017379760742
layer:19, std:1.000280737876892
layer:20, std:0.9943006038665771
layer:21, std:1.012800931930542
layer:22, std:1.012657642364502
layer:23, std:1.018149971961975
layer:24, std:0.9776086211204529
layer:25, std:0.9592394828796387
layer:26, std:0.9317858815193176
layer:27, std:0.9534041881561279
layer:28, std:0.9811319708824158
layer:29, std:0.9953019022941589
layer:30, std:0.9773916006088257
layer:31, std:0.9655940532684326
layer:32, std:0.9270440936088562
layer:33, std:0.9329946637153625
layer:34, std:0.9311841726303101
layer:35, std:0.9354336261749268
layer:36, std:0.9492132067680359
layer:37, std:0.9679954648017883
layer:38, std:0.9849981665611267
layer:39, std:0.9982335567474365
layer:40, std:0.9616852402687073
layer:41, std:0.9439758658409119
layer:42, std:0.9631161093711853
layer:43, std:0.958673894405365
layer:44, std:0.9675614237785339
layer:45, std:0.9837557077407837
layer:46, std:0.9867278337478638
layer:47, std:0.9920817017555237
layer:48, std:0.9650403261184692
layer:49, std:0.9991624355316162
layer:50, std:0.9946174025535583
layer:51, std:0.9662044048309326
layer:52, std:0.9827387928962708
layer:53, std:0.9887880086898804
layer:54, std:0.9932605624198914
layer:55, std:1.0237400531768799
layer:56, std:0.9702046513557434
layer:57, std:1.0045380592346191
layer:58, std:0.9943899512290955
layer:59, std:0.9900636076927185
layer:60, std:0.99446702003479
layer:61, std:0.9768352508544922
layer:62, std:0.9797843098640442
layer:63, std:0.9951220750808716
layer:64, std:0.9980446696281433
layer:65, std:1.0086933374404907
layer:66, std:1.0276142358779907
layer:67, std:1.0429234504699707
layer:68, std:1.0197855234146118
layer:69, std:1.0319130420684814
layer:70, std:1.0540012121200562
layer:71, std:1.026781439781189
layer:72, std:1.0331352949142456
layer:73, std:1.0666675567626953
layer:74, std:1.0413838624954224
layer:75, std:1.0733673572540283
layer:76, std:1.0404183864593506
layer:77, std:1.0344083309173584
layer:78, std:1.0022705793380737
layer:79, std:0.99835205078125
layer:80, std:0.9732587337493896
layer:81, std:0.9777462482452393
layer:82, std:0.9753198623657227
layer:83, std:0.9938382506370544
layer:84, std:0.9472599029541016
layer:85, std:0.9511011242866516
layer:86, std:0.9737769961357117
layer:87, std:1.005651831626892
layer:88, std:1.0043526887893677
layer:89, std:0.9889539480209351
layer:90, std:1.0130352973937988
layer:91, std:1.0030947923660278
layer:92, std:0.9993206262588501
layer:93, std:1.0342745780944824
layer:94, std:1.031973123550415
layer:95, std:1.0413124561309814
layer:96, std:1.0817031860351562
layer:97, std:1.128799557685852
layer:98, std:1.1617802381515503
layer:99, std:1.2215303182601929
tensor([[-1.0696, -1.1373,  0.5047,  ..., -0.4766,  1.5904, -0.1076],
        [ 0.4572,  1.6211,  1.9659,  ..., -0.3558, -1.1235,  0.0979],
        [ 0.3908, -0.9998, -0.8680,  ..., -2.4161,  0.5035,  0.2814],
        ...,
        [ 0.1876,  0.7971, -0.5918,  ...,  0.5395, -0.8932,  0.1211],
        [-0.0102, -1.5027, -2.6860,  ...,  0.6954, -0.1858, -0.8027],
        [-0.5871, -1.3739, -2.9027,  ...,  1.6734,  0.5094, -0.9986]],
       grad_fn=<MmBackward>)

可以看到,第 99 层的输出值都在一个比较正常的范围,并且每一层输出数据的标准差都在 $1$ 附近,所以现在我们得到了一个比较理想的输出数据分布。代码实验的结果也验证了我们前面公式推导的正确性:通过采用合适的权值初始化方法,可以使得多层全连接网络的输出值的数据尺度维持在一定范围内,而不会变得过大或者过小。

在上面的例子中,我们通过权重初始化保证了每层输出数据的方差为 $1$,但是这里我们还没有考虑激活函数的存在。下面我们看一下 具有激活函数时的权值初始化。我们在前向传播 forward 函数中加入 tanh 激活函数:

1
2
3
4
5
6
7
8
9
10
11
def forward(self, x):
    for (i, linear) in enumerate(self.linears):
        x = linear(x)
        x = torch.tanh(x)

        print("layer:{}, std:{}".format(i, x.std()))  # 打印每层的标准差
        if torch.isnan(x.std()):
            print("output is nan in {} layers".format(i))
            break

    return x

输出结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
layer:0, std:0.6273701786994934
layer:1, std:0.48910173773765564
layer:2, std:0.4099564850330353
layer:3, std:0.35637012124061584
layer:4, std:0.32117360830307007
layer:5, std:0.2981105148792267
layer:6, std:0.27730831503868103
layer:7, std:0.2589356303215027
layer:8, std:0.2468511462211609
layer:9, std:0.23721906542778015
layer:10, std:0.22171513736248016
layer:11, std:0.21079954504966736
layer:12, std:0.19820132851600647
layer:13, std:0.19069305062294006
layer:14, std:0.18555502593517303
layer:15, std:0.17953835427761078
layer:16, std:0.17485806345939636
layer:17, std:0.1702701896429062
layer:18, std:0.16508983075618744
layer:19, std:0.1591130942106247
layer:20, std:0.15480300784111023
layer:21, std:0.15263864398002625
layer:22, std:0.148549422621727
layer:23, std:0.14617665112018585
layer:24, std:0.13876432180404663
layer:25, std:0.13316625356674194
layer:26, std:0.12660598754882812
layer:27, std:0.12537942826747894
layer:28, std:0.12535445392131805
layer:29, std:0.12589804828166962
layer:30, std:0.11994210630655289
layer:31, std:0.11700887233018875
layer:32, std:0.11137297749519348
layer:33, std:0.11154612898826599
layer:34, std:0.10991233587265015
layer:35, std:0.10996390879154205
layer:36, std:0.10969001054763794
layer:37, std:0.10975216329097748
layer:38, std:0.11063200235366821
layer:39, std:0.11021336913108826
layer:40, std:0.10465587675571442
layer:41, std:0.10141163319349289
layer:42, std:0.1026025265455246
layer:43, std:0.10079070925712585
layer:44, std:0.10096712410449982
layer:45, std:0.10117629915475845
layer:46, std:0.10145658999681473
layer:47, std:0.09987485408782959
layer:48, std:0.09677786380052567
layer:49, std:0.099615179002285
layer:50, std:0.09867013245820999
layer:51, std:0.09398546814918518
layer:52, std:0.09388342499732971
layer:53, std:0.09352942556142807
layer:54, std:0.09336657077074051
layer:55, std:0.0948176234960556
layer:56, std:0.08856320381164551
layer:57, std:0.09024856984615326
layer:58, std:0.088644839823246
layer:59, std:0.08766943216323853
layer:60, std:0.08726289123296738
layer:61, std:0.08623495697975159
layer:62, std:0.08549778908491135
layer:63, std:0.0855521708726883
layer:64, std:0.0853666365146637
layer:65, std:0.08462794870138168
layer:66, std:0.0852193832397461
layer:67, std:0.08562126755714417
layer:68, std:0.08368431031703949
layer:69, std:0.08476374298334122
layer:70, std:0.0853630006313324
layer:71, std:0.08237560093402863
layer:72, std:0.08133518695831299
layer:73, std:0.08416958898305893
layer:74, std:0.08226992189884186
layer:75, std:0.08379074186086655
layer:76, std:0.08003697544336319
layer:77, std:0.07888862490653992
layer:78, std:0.07618380337953568
layer:79, std:0.07458437979221344
layer:80, std:0.07207276672124863
layer:81, std:0.07079190015792847
layer:82, std:0.0712786465883255
layer:83, std:0.07165777683258057
layer:84, std:0.06893909722566605
layer:85, std:0.0690247192978859
layer:86, std:0.07030878216028214
layer:87, std:0.07283661514520645
layer:88, std:0.07280214875936508
layer:89, std:0.07130246609449387
layer:90, std:0.07225215435028076
layer:91, std:0.0712454542517662
layer:92, std:0.07088854163885117
layer:93, std:0.0730612576007843
layer:94, std:0.07276967912912369
layer:95, std:0.07259567081928253
layer:96, std:0.07586522400379181
layer:97, std:0.07769150286912918
layer:98, std:0.07842090725898743
layer:99, std:0.08206238597631454
tensor([[-0.1103, -0.0739,  0.1278,  ..., -0.0508,  0.1544, -0.0107],
        [ 0.0807,  0.1208,  0.0030,  ..., -0.0385, -0.1887, -0.0294],
        [ 0.0321, -0.0833, -0.1482,  ..., -0.1133,  0.0206,  0.0155],
        ...,
        [ 0.0108,  0.0560, -0.1099,  ...,  0.0459, -0.0961, -0.0124],
        [ 0.0398, -0.0874, -0.2312,  ...,  0.0294, -0.0562, -0.0556],
        [-0.0234, -0.0297, -0.1155,  ...,  0.1143,  0.0083, -0.0675]],
       grad_fn=<TanhBackward>)

可以看到,随着网络层的前向传播,每层输出值的标准差越来越小,最终可能会导致 梯度消失,这并不是我们所希望看到的。

2. Xavier 初始化

参考文献Understanding the difficulty of training deep feedforward neural networks

针对上面具有激活函数情况的问题,2010 年 Xavier 在一篇论文中详细探讨了在具有激活函数的情况下应该如何初始化的问题。在文献中,结合方差一致性原则 (即让每层网络输出值的方差尽量在 $1$ 附近),同时作者对 Sigmoid、tanh 这类饱和激活函数进行分析。

方差一致性:保持数据尺度维持在恰当范围,通常方差为 $1$。

激活函数:饱和函数,如 Sigmoid、Tanh。

通过论文中的公式推导,我们可以得到以下两个等式:

\[n_i * \mathrm{Var}(W) =1\] \[n_{i+1} * \mathrm{Var}(W) =1\]

其中,$n_i$ 是输入层的神经元个数,$n_{i+1}$ 是输出层的神经元个数。即我们同时考虑了前向传播和反向传播过程中的数据尺度问题。

同时,结合方差一致性原则,我们可以得到权值 $W$ 的方差为:

\[\mathrm{Var}(W) = \dfrac{2}{n_i + n_{i+1}}\]

通常,Xavier 采用的是均匀分布。下面我们来推导均匀分布是上限和下限,这里我们假设上限是 $a$,那么下限为 $-a$,因为我们通常采用的是零均值,所以上下限之间是对称关系。

\[W \sim U[-a, a]\]

根据均匀分布的方差公式,我们得到:

\[\mathrm{Var}(W) = \dfrac{(-a-a)^2}{12} = \dfrac{(2a)^2}{12} = \dfrac{a^2}{3}\]

然后,我们有:

\[\dfrac{2}{n_i + n_{i+1}} = \dfrac{a^2}{3} \quad \Longrightarrow \quad a = \dfrac{\sqrt{6}}{\sqrt{n_i + n_{i+1}}}\]

所以,

\[W \sim U\left[- \dfrac{\sqrt{6}}{\sqrt{n_i + n_{i+1}}}, \dfrac{\sqrt{6}}{\sqrt{n_i + n_{i+1}}} \right]\]

代码示例

我们可以通过手动计算实现 Xavier 初始化:

1
2
3
4
5
6
7
def initialize(self):
    for m in self.modules():
        if isinstance(m, nn.Linear):
            a = np.sqrt(6 / (self.neural_num + self.neural_num))
            tanh_gain = nn.init.calculate_gain('tanh')  # 计算激活函数的增益
            a *= tanh_gain
            nn.init.uniform_(m.weight.data, -a, a)

另外,PyTorch 中也内置了 Xavier 初始化方法,其结果和我们手动计算的结果是一致的:

1
2
3
4
5
def initialize(self):
    for m in self.modules():
        if isinstance(m, nn.Linear):
            tanh_gain = nn.init.calculate_gain('tanh')
            nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain)

输出结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
layer:0, std:0.7571136355400085
layer:1, std:0.6924336552619934
layer:2, std:0.6677976846694946
layer:3, std:0.6551960110664368
layer:4, std:0.655646800994873
layer:5, std:0.6536089777946472
layer:6, std:0.6500504612922668
layer:7, std:0.6465446949005127
layer:8, std:0.6456685662269592
layer:9, std:0.6414617896080017
layer:10, std:0.6423627734184265
layer:11, std:0.6509683728218079
layer:12, std:0.6584846377372742
layer:13, std:0.6530249118804932
layer:14, std:0.6528729796409607
layer:15, std:0.6523412466049194
layer:16, std:0.6534921526908875
layer:17, std:0.6540238261222839
layer:18, std:0.6477403044700623
layer:19, std:0.6469652652740479
layer:20, std:0.6441705822944641
layer:21, std:0.6484488248825073
layer:22, std:0.6512865424156189
layer:23, std:0.6525684595108032
layer:24, std:0.6531476378440857
layer:25, std:0.6488809585571289
layer:26, std:0.6533839702606201
layer:27, std:0.6482065320014954
layer:28, std:0.6471589803695679
layer:29, std:0.6553042531013489
layer:30, std:0.6560811400413513
layer:31, std:0.6522760987281799
layer:32, std:0.6499098539352417
layer:33, std:0.6568747758865356
layer:34, std:0.6544532179832458
layer:35, std:0.6535674929618835
layer:36, std:0.6508696675300598
layer:37, std:0.6428772807121277
layer:38, std:0.6495102643966675
layer:39, std:0.6479291319847107
layer:40, std:0.6470604538917542
layer:41, std:0.6513484716415405
layer:42, std:0.6503545045852661
layer:43, std:0.6458993554115295
layer:44, std:0.6517387628555298
layer:45, std:0.6520006060600281
layer:46, std:0.6539937257766724
layer:47, std:0.6537032723426819
layer:48, std:0.6516646146774292
layer:49, std:0.6535552740097046
layer:50, std:0.6464877724647522
layer:51, std:0.6491119265556335
layer:52, std:0.6455202102661133
layer:53, std:0.6520237326622009
layer:54, std:0.6531855463981628
layer:55, std:0.6627183556556702
layer:56, std:0.6544181108474731
layer:57, std:0.6501768827438354
layer:58, std:0.6510448455810547
layer:59, std:0.6549468040466309
layer:60, std:0.6529951691627502
layer:61, std:0.6515748500823975
layer:62, std:0.6453633904457092
layer:63, std:0.644793689250946
layer:64, std:0.6489539742469788
layer:65, std:0.6553947925567627
layer:66, std:0.6535270810127258
layer:67, std:0.6528791785240173
layer:68, std:0.6492816209793091
layer:69, std:0.6596571207046509
layer:70, std:0.6536712646484375
layer:71, std:0.6498764157295227
layer:72, std:0.6538681387901306
layer:73, std:0.64595627784729
layer:74, std:0.6543275117874146
layer:75, std:0.6525828838348389
layer:76, std:0.6462088227272034
layer:77, std:0.6534948945045471
layer:78, std:0.6461930871009827
layer:79, std:0.6457878947257996
layer:80, std:0.6481245160102844
layer:81, std:0.6496317386627197
layer:82, std:0.6516988277435303
layer:83, std:0.6485154032707214
layer:84, std:0.6395408511161804
layer:85, std:0.6498249173164368
layer:86, std:0.6510564088821411
layer:87, std:0.6505221724510193
layer:88, std:0.6573457717895508
layer:89, std:0.6529723405838013
layer:90, std:0.6536353230476379
layer:91, std:0.6497699022293091
layer:92, std:0.6459059715270996
layer:93, std:0.6459072232246399
layer:94, std:0.6530925631523132
layer:95, std:0.6515892148017883
layer:96, std:0.6434286832809448
layer:97, std:0.6425578594207764
layer:98, std:0.6407340168952942
layer:99, std:0.6442393660545349
tensor([[ 0.1133,  0.1239,  0.8211,  ...,  0.9411, -0.6334,  0.5155],
        [-0.9585, -0.2371,  0.8548,  ..., -0.2339,  0.9326,  0.0114],
        [ 0.9487, -0.2279,  0.8735,  ..., -0.9593,  0.7922,  0.6263],
        ...,
        [ 0.7257,  0.0800, -0.4440,  ..., -0.9589,  0.2604,  0.5402],
        [-0.9572,  0.5179, -0.8041,  ..., -0.4298, -0.6087,  0.9679],
        [ 0.6105,  0.3994,  0.1072,  ...,  0.3904, -0.5274,  0.0776]],
       grad_fn=<TanhBackward>)

可以看到,每层网络输出值的标准差都在 $0.65$ 左右,这表明每层网络的输出值都不会过大或者过小。并且,最后第 99 层的输出值也在一个比较正常的范围内。

3. Kaiming 初始化

虽然在 2010 年,Xavier 针对诸如 Sigmoid、tanh 这类饱和激活函数给出了有效的初始化方法。但是,也是在同一年 Xnet 出现之后,非饱和激活函数 ReLU 被广泛使用,由于非饱和函数的性质,Xavier 初始化方法将不再适用。

在下面的代码中,将激活函数改为 ReLU,并且仍然使用 Xavier 初始化方法,我们来观察一下网络层的输出:

1
2
3
4
5
6
7
8
9
10
11
def forward(self, x):
    for (i, linear) in enumerate(self.linears):
        x = linear(x)
        x = torch.relu(x)

        print("layer:{}, std:{}".format(i, x.std()))  # 打印每层的标准差
        if torch.isnan(x.std()):
            print("output is nan in {} layers".format(i))
            break

    return x

输出结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
layer:0, std:0.9689465165138245
layer:1, std:1.0872339010238647
layer:2, std:1.2967971563339233
layer:3, std:1.4487521648406982
layer:4, std:1.8563750982284546
layer:5, std:2.2424941062927246
layer:6, std:2.679966449737549
layer:7, std:3.2177586555480957
layer:8, std:3.8579354286193848
layer:9, std:4.413454532623291
layer:10, std:5.518202781677246
layer:11, std:6.072154521942139
layer:12, std:7.441657543182373
layer:13, std:8.963356971740723
layer:14, std:10.747811317443848
layer:15, std:13.216470718383789
layer:16, std:15.070353507995605
layer:17, std:17.297853469848633
layer:18, std:19.603160858154297
layer:19, std:22.72492218017578
layer:20, std:27.811525344848633
layer:21, std:34.64209747314453
layer:22, std:43.16114044189453
layer:23, std:51.901859283447266
layer:24, std:59.3619384765625
layer:25, std:63.65275955200195
layer:26, std:61.95321273803711
layer:27, std:79.72232055664062
layer:28, std:99.41972351074219
layer:29, std:118.13148498535156
layer:30, std:128.12930297851562
layer:31, std:140.68907165527344
layer:32, std:165.6183319091797
layer:33, std:196.19956970214844
layer:34, std:214.2675323486328
layer:35, std:282.7183532714844
layer:36, std:317.1474304199219
layer:37, std:373.9003601074219
layer:38, std:412.70892333984375
layer:39, std:529.4519653320312
layer:40, std:532.9295654296875
layer:41, std:630.9380493164062
layer:42, std:762.7489624023438
layer:43, std:813.0692138671875
layer:44, std:1022.0352783203125
layer:45, std:1363.53759765625
layer:46, std:1734.6246337890625
layer:47, std:1899.3427734375
layer:48, std:2251.1640625
layer:49, std:2680.478759765625
layer:50, std:3370.64794921875
layer:51, std:4003.856201171875
layer:52, std:4598.98779296875
layer:53, std:5199.58447265625
layer:54, std:6399.32568359375
layer:55, std:8127.16064453125
layer:56, std:9794.875
layer:57, std:11728.7431640625
layer:58, std:15471.70703125
layer:59, std:19942.44921875
layer:60, std:22642.23046875
layer:61, std:28904.16796875
layer:62, std:37538.265625
layer:63, std:41843.19921875
layer:64, std:48306.828125
layer:65, std:56072.6171875
layer:66, std:59877.83984375
layer:67, std:57911.44140625
layer:68, std:68525.5859375
layer:69, std:78614.9609375
layer:70, std:103845.0
layer:71, std:121762.9375
layer:72, std:128452.984375
layer:73, std:146725.484375
layer:74, std:168575.125
layer:75, std:176617.84375
layer:76, std:202430.046875
layer:77, std:247756.625
layer:78, std:310793.875
layer:79, std:374327.34375
layer:80, std:456118.53125
layer:81, std:545246.25
layer:82, std:550071.8125
layer:83, std:653713.0
layer:84, std:831133.9375
layer:85, std:1045186.3125
layer:86, std:1184264.0
layer:87, std:1334159.5
layer:88, std:1589417.75
layer:89, std:1783507.25
layer:90, std:2239068.0
layer:91, std:2429546.0
layer:92, std:2928562.0
layer:93, std:2883037.75
layer:94, std:3230928.75
layer:95, std:3661650.75
layer:96, std:4741352.0
layer:97, std:5300345.0
layer:98, std:6797732.0
layer:99, std:7640650.0
tensor([[       0.0000,  3028736.2500, 12379591.0000,  ...,
          3593889.2500,        0.0000, 24658882.0000],
        [       0.0000,  2758786.0000, 11016991.0000,  ...,
          2970399.7500,        0.0000, 23173860.0000],
        [       0.0000,  2909416.2500, 13117423.0000,  ...,
          3867089.2500,        0.0000, 28463550.0000],
        ...,
        [       0.0000,  3913293.2500, 15489672.0000,  ...,
          5777762.0000,        0.0000, 33226552.0000],
        [       0.0000,  3673798.2500, 12739622.0000,  ...,
          4193501.0000,        0.0000, 26862400.0000],
        [       0.0000,  1913917.0000, 10243700.0000,  ...,
          4573404.0000,        0.0000, 22720538.0000]],
       grad_fn=<ReluBackward0>)

可以看到,当激活函数改为 ReLU 之后,如果我们仍然采用 Xavier 初始化,那么每一层的输出值标准差将逐渐增大:由最初第 0 层的 $1$ 左右逐渐增大到最终第 99 层的 $764$ 万。这并不是我们所希望的。

针对这一问题,在 2015 年,何凯明等人在一篇论文中提出了解决方法:

参考文献Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification

方差一致性:保持数据尺度维持在恰当范围,通常方差为 $1$。

激活函数:ReLU 及其变种。

在论文中,我们同样遵循方差一致性原则,即使得输出层方差为 $1$。论文中针对 ReLU 激活函数,通过公式推导,得到权值 $W$ 的方差为:

\[\mathrm{Var}(W) = \dfrac{2}{n_i}\]

其中,$n_i$ 为输入层的神经元个数。

进一步地,针对 ReLU 的变种,即在激活函数的负半轴上给予一定斜率的情况下,权值 $W$ 的方差为:

\[\mathrm{Var}(W) = \dfrac{2}{(1+a^2)*n_i}\]

其中,$a$ 为激活函数在负半轴上的斜率。$a=0$ 时激活函数即为原始的 ReLU。

由此得到权值 $W$ 的标准差为:

\[\mathrm{Std}(W) = \sqrt{\dfrac{2}{(1+a^2)*n_i}}\]

下面我们根据这一公式进行权值初始化,并观察各网络层的输出。我们可以采用手动计算实现 Kaiming 初始化方法:

1
2
3
4
def initialize(self):
    for m in self.modules():
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))

或者直接使用 PyTorch 中内置的 Kaiming 初始化方法,两者的结果是一致的:

1
2
3
4
def initialize(self):
    for m in self.modules():
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight.data)

输出结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
layer:0, std:0.826629638671875
layer:1, std:0.8786815404891968
layer:2, std:0.9134422540664673
layer:3, std:0.8892471194267273
layer:4, std:0.834428071975708
layer:5, std:0.874537467956543
layer:6, std:0.7926971316337585
layer:7, std:0.7806458473205566
layer:8, std:0.8684563636779785
layer:9, std:0.9434137344360352
layer:10, std:0.964215874671936
layer:11, std:0.8896796107292175
layer:12, std:0.8287257552146912
layer:13, std:0.8519769906997681
layer:14, std:0.8354345560073853
layer:15, std:0.802306056022644
layer:16, std:0.8613607287406921
layer:17, std:0.7583686709403992
layer:18, std:0.8120225071907043
layer:19, std:0.791111171245575
layer:20, std:0.7164372801780701
layer:21, std:0.778393030166626
layer:22, std:0.8672043085098267
layer:23, std:0.874812662601471
layer:24, std:0.9020991325378418
layer:25, std:0.8585715889930725
layer:26, std:0.7824353575706482
layer:27, std:0.7968912720680237
layer:28, std:0.8984369039535522
layer:29, std:0.8704465627670288
layer:30, std:0.9860473275184631
layer:31, std:0.9080777168273926
layer:32, std:0.9140636920928955
layer:33, std:1.009956955909729
layer:34, std:0.9909380674362183
layer:35, std:1.0253208875656128
layer:36, std:0.849043607711792
layer:37, std:0.703953742980957
layer:38, std:0.7186155319213867
layer:39, std:0.7250635027885437
layer:40, std:0.7030817270278931
layer:41, std:0.6325559020042419
layer:42, std:0.6623690724372864
layer:43, std:0.6960875988006592
layer:44, std:0.7140733003616333
layer:45, std:0.632905125617981
layer:46, std:0.6458898186683655
layer:47, std:0.7354375720024109
layer:48, std:0.6710687279701233
layer:49, std:0.6939153671264648
layer:50, std:0.6889258027076721
layer:51, std:0.6331773996353149
layer:52, std:0.6029313206672668
layer:53, std:0.6145528554916382
layer:54, std:0.6636686325073242
layer:55, std:0.7440094947814941
layer:56, std:0.7972175478935242
layer:57, std:0.7606149911880493
layer:58, std:0.696868360042572
layer:59, std:0.7306802272796631
layer:60, std:0.6875627636909485
layer:61, std:0.7171440720558167
layer:62, std:0.7646605372428894
layer:63, std:0.7965086698532104
layer:64, std:0.8833740949630737
layer:65, std:0.8592952489852905
layer:66, std:0.8092936873435974
layer:67, std:0.806481122970581
layer:68, std:0.6792410612106323
layer:69, std:0.6583346128463745
layer:70, std:0.5702278017997742
layer:71, std:0.5084435939788818
layer:72, std:0.4869326055049896
layer:73, std:0.46350404620170593
layer:74, std:0.4796811640262604
layer:75, std:0.47372108697891235
layer:76, std:0.45414549112319946
layer:77, std:0.4971912205219269
layer:78, std:0.492794930934906
layer:79, std:0.4422350823879242
layer:80, std:0.4802998900413513
layer:81, std:0.5579248666763306
layer:82, std:0.5283755660057068
layer:83, std:0.5451980829238892
layer:84, std:0.6203726530075073
layer:85, std:0.6571893095970154
layer:86, std:0.703682005405426
layer:87, std:0.7321067452430725
layer:88, std:0.6924356818199158
layer:89, std:0.6652532815933228
layer:90, std:0.6728308796882629
layer:91, std:0.6606621742248535
layer:92, std:0.6094604730606079
layer:93, std:0.6019102334976196
layer:94, std:0.595421552658081
layer:95, std:0.6624555587768555
layer:96, std:0.6377885341644287
layer:97, std:0.6079285740852356
layer:98, std:0.6579315066337585
layer:99, std:0.6668476462364197
tensor([[0.0000, 1.3437, 0.0000,  ..., 0.0000, 0.6444, 1.1867],
        [0.0000, 0.9757, 0.0000,  ..., 0.0000, 0.4645, 0.8594],
        [0.0000, 1.0023, 0.0000,  ..., 0.0000, 0.5148, 0.9196],
        ...,
        [0.0000, 1.2873, 0.0000,  ..., 0.0000, 0.6454, 1.1411],
        [0.0000, 1.3589, 0.0000,  ..., 0.0000, 0.6749, 1.2438],
        [0.0000, 1.1807, 0.0000,  ..., 0.0000, 0.5668, 1.0600]],
       grad_fn=<ReluBackward0>)

可以看到,现在每个网络层输出值的标准差都能维持在一个相同的尺度上,不会过大或者过小,并且输出值也基本都在正常范围内。

4. 常用的权值始化方法

通过前面的例子,我们对于权值的初始化方法有了清晰的认识。我们知道,不适当的权值初始化方法会引起网络层的输出值过大或者过小,从而引发梯度的消失或者爆炸,最终导致我们的模型无法正常训练。为了避免这种现象的发生,我们必须控制各网络层输出值的尺度范围。根据公式的推断过程我们知道,我们必须使每个网络层的输出值的方差尽量在 $1$ 附近,即遵循方差一致性原则,使得输出值方差不会过大或者过小。

下面我们来学习 PyTorch 中提供的 10 种常用的权值初始化方法:

  1. Xavier 均匀分布
  2. Xavier 正态分布
  3. Kaiming 均匀分布
  4. Kaiming 正态分布
  5. 均匀分布
  6. 正态分布
  7. 常数分布
  8. 正交矩阵初始化
  9. 单位矩阵初始化
  10. 稀疏矩阵初始化

这里可以分为 4 大类:Xavier 初始化、Kaiming 初始化、三种常用分布初始化,以及三种特殊的矩阵初始化。PyTorch 中提供了这些方法,方便我们进行权值初始化。那么,在进行权值初始化的时候,我们应该选择哪种初始化方法呢?这需要具体问题具体分析,但是无论我们使用哪种初始化方法,都需要遵循方差一致性原则,即每层输出值的方差不能太大或者太小,尽量保持在 $1$ 附近。

下面我们来学习一个特殊的函数 calculate_gain,它被用于计算激活函数的方差变化尺度。

nn.init.calculate_gain

功能:计算激活函数的 方差变化尺度

1
nn.init.calculate_gain(nonlinearity, param=None)

主要参数

  • nonlinearity:激活函数名称,例如:tanh、Sigmoid、ReLU 等等。
  • param:激活函数的参数,例如:Leaky ReLU 中的 negative_slop

实际上,该函数计算的是输入数据的方差除以输出数据的方差,即方差变化的比例。

代码示例

1
2
3
4
5
6
7
8
x = torch.randn(10000)  # 通过标准正态分布创建 10000 个数据点
out = torch.tanh(x)  # 将数据输入 tanh 函数

gain = x.std() / out.std()  # 手动计算激活函数的标准差变化尺度
print('gain:{}'.format(gain))

tanh_gain = nn.init.calculate_gain('tanh')  # calculate_gain 计算的激活函数的标准差增益
print('tanh_gain in PyTorch:', tanh_gain)

输出结果:

1
2
gain:1.5982500314712524
tanh_gain in PyTorch: 1.6666666666666667

可以看到,tanh 激活函数的标准差增益在 $1.6$ 左右,也就是说,对于均值为 $0$,标准差为 $1$ 的数据,经过 tanh 函数之后,数据的标准差会减小 $1.6$ 倍左右。

5. 总结

本节课中,我们学习了权值初始化方法的准则 —— 方差一致性原则,以及 Xavier 和 Kaiming 权值初始化方法。在下节课中,我们将学习损失函数。

下节内容:损失函数 (一)

知识共享许可协议本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。 欢迎转载,并请注明来自:YEY 的博客 同时保持文章内容的完整和以上声明信息!