Lecture 10 模型创建步骤与 nn.Module
前几节课中,我们学习了 PyTorch 的数据模块,并了解了 PyTorch 如何从硬盘中读取数据,然后对数据进行预处理、数据增强,最后转换为张量的形式输入到我们的模型中。在深度模型中,会对张量进行一系列复杂的数学运算,最终得到用于分类、分割、目标检测等任务的输入。本节课中,我们将学习 PyTorch 中模型的创建以及 nn.Module
的相关概念。
1. 网络模型的创建步骤
在学习创建模型之前,我们先回顾一下之前提到的机器学习模型训练的 5 个步骤:
我们已经在前几节课中完成了对数据模块的学习,接下来我们开始学习模型模块。
回顾一下之前在人民币分类的例子中我们使用过的 LeNet 网络:
LeNet 模型结构图:
可以看到,LeNet 网络由 7 个层构成:卷积层 1、池化层 1、卷积层 2、池化层 2,以及 3 个全连接层。在创建 LeNet 时,需要先构建这些子模块,在构建完成这 7 个子网络层后,我们会采用一定的顺序对其进行连接。最后,将它们包装起来就得到我们的 LeNet 网络。
在 PyTorch 中,LeNet 是一个 Module 的概念,而它的子网络层也是一个 Module 的概念,它们都属于 nn.Module
类。所以,一个 nn.Module
(例如:LeNet) 可以包含很多个子 Module (例如:卷积层、池化层等)。
下面我们从计算图的角度来观察模型的创建过程:
计算图中有两个主要的概念:结点和边。其中,结点代表张量 (数据),边代表运算。LeNet 整体上可以视为一组张量运算:它接收一个 $32\times 32\times 3$ 的张量,经过一系列复杂运算之后,输出一个长度为 $10$ 的向量作为分类概率。而在 LeNet 内部,则由一系列子网络层构成,例如:卷积层 1 对一个 $32\times 32\times 3$ 的张量进行卷积操作得到一个 $28\times 28\times 6$ 的张量,并将其作为下一层子网络的输入,经过这种不断的前向传播,最终计算得到输出概率。在深度学习中,该过程被称为 前向传播。
我们从网络结构和计算图的角度分析了 LeNet 网络模型,并且知道了构建模型的两个要素:构建子模块和拼接子模块。
接下来,我们还是通过之前人民币二分类的例子来学习如何构建模型。
构建模型:
1
2
3
# ============================ step 2/5 模型 ============================
net = LeNet(classes=2)
net.initialize_weights()
LeNet 类:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class LeNet(nn.Module):
# 构建子模块
def __init__(self, classes):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, classes)
# 拼接子模块
def forward(self, x):
out = F.relu(self.conv1(x))
out = F.max_pool2d(out, 2)
out = F.relu(self.conv2(out))
out = F.max_pool2d(out, 2)
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
out = self.fc3(out)
return out
def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight.data, 0, 0.1)
m.bias.data.zero_()
2. nn.Module
在模型模块中,我们有一个非常重要的概念 —— nn.Module
。我们所有的模型和网络层都是继承自 nn.Module
这个类的,所以我们有必要了解它。在学习 nn.Module
之前,我们先来看一下与其相关的几个模块:
首先是 torch.nn
,它是 PyTorch 的一个神经网络模块,其中又有很多子模块,这里我们需要了解其中的 4 个模块:nn.Parameter
、nn.Module
、nn.functional
和 nn.init
。本节课我们先重点关注 nn.Module
。
nn.Module
在 nn.Module
中有 8 个重要的属性,用于管理整个模型:
1
2
3
4
5
6
7
8
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
主要属性:
parameters
:存储管理nn.Parameter
类。modules
:存储管理nn.Module
类。buffers
:存储管理缓冲属性,如 BN 层中的running_mean
。***_hooks
:存储管理钩子函数。
这里,我们重点关注其中的两个属性:parameters
和 modules
。
nn.Module
的属性构建机制:
在 module
类里面进行属性赋值时会先被 __setattr__
函数拦截,该函数对即将赋值的数据类型进行判断:如果赋值是 nn.Parameter
类,则将其存入 parameters
字典中进行管理;如果赋值是 nn.Module
类,则将其存入 modules
字典中进行管理。
nn.Module
总结:
- 一个 module 可以包含多个子 module。
- 例如:在 LeNet 这个 module 下会包含一些卷积层、池化层等子 module。
- 一个 module 相当于一个运算,必须实现
forward()
函数。- 从计算图的角度来看,一个 module 接收一个张量,经过一系列复杂运算,输出概率或者其他数据。因此,我们需要在其中实现一个前向传播的函数。
- 每个 module 都有 8 个 有序字典 (
OrderedDict
) 管理它的属性。- 这里,最常用的是
parameters
字典和modules
字典。
- 这里,最常用的是
3. 总结
本节课中,我们学习了 nn.Module
的概念以及模型创建的两个要素。下节课中,我们将学习容器 Containers 以及 AlexNet 的搭建。
下节内容:模型容器与 AlexNet 构建
本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。 欢迎转载,并请注明来自:YEY 的博客 同时保持文章内容的完整和以上声明信息!