PyTorch 10:模型创建步骤与 nn.Module

PyTorch 学习笔记

Posted by YEY on December 14, 2020

Lecture 10 模型创建步骤与 nn.Module

前几节课中,我们学习了 PyTorch 的数据模块,并了解了 PyTorch 如何从硬盘中读取数据,然后对数据进行预处理、数据增强,最后转换为张量的形式输入到我们的模型中。在深度模型中,会对张量进行一系列复杂的数学运算,最终得到用于分类、分割、目标检测等任务的输入。本节课中,我们将学习 PyTorch 中模型的创建以及 nn.Module 的相关概念。

1. 网络模型的创建步骤

在学习创建模型之前,我们先回顾一下之前提到的机器学习模型训练的 5 个步骤:

我们已经在前几节课中完成了对数据模块的学习,接下来我们开始学习模型模块。

回顾一下之前在人民币分类的例子中我们使用过的 LeNet 网络:

LeNet 模型结构图

可以看到,LeNet 网络由 7 个层构成:卷积层 1、池化层 1、卷积层 2、池化层 2,以及 3 个全连接层。在创建 LeNet 时,需要先构建这些子模块,在构建完成这 7 个子网络层后,我们会采用一定的顺序对其进行连接。最后,将它们包装起来就得到我们的 LeNet 网络。

在 PyTorch 中,LeNet 是一个 Module 的概念,而它的子网络层也是一个 Module 的概念,它们都属于 nn.Module 类。所以,一个 nn.Module (例如:LeNet) 可以包含很多个子 Module (例如:卷积层、池化层等)。

下面我们从计算图的角度来观察模型的创建过程:

计算图中有两个主要的概念:结点和边。其中,结点代表张量 (数据),边代表运算。LeNet 整体上可以视为一组张量运算:它接收一个 $32\times 32\times 3$ 的张量,经过一系列复杂运算之后,输出一个长度为 $10$ 的向量作为分类概率。而在 LeNet 内部,则由一系列子网络层构成,例如:卷积层 1 对一个 $32\times 32\times 3$ 的张量进行卷积操作得到一个 $28\times 28\times 6$ 的张量,并将其作为下一层子网络的输入,经过这种不断的前向传播,最终计算得到输出概率。在深度学习中,该过程被称为 前向传播

我们从网络结构和计算图的角度分析了 LeNet 网络模型,并且知道了构建模型的两个要素:构建子模块和拼接子模块。

接下来,我们还是通过之前人民币二分类的例子来学习如何构建模型。

构建模型

1
2
3
# ============================ step 2/5 模型 ============================
net = LeNet(classes=2)
net.initialize_weights()

LeNet 类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class LeNet(nn.Module):
    # 构建子模块
    def __init__(self, classes):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    # 拼接子模块
    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data, 0, 0.1)
                m.bias.data.zero_()

2. nn.Module

在模型模块中,我们有一个非常重要的概念 —— nn.Module。我们所有的模型和网络层都是继承自 nn.Module 这个类的,所以我们有必要了解它。在学习 nn.Module 之前,我们先来看一下与其相关的几个模块:

首先是 torch.nn,它是 PyTorch 的一个神经网络模块,其中又有很多子模块,这里我们需要了解其中的 4 个模块:nn.Parameternn.Modulenn.functionalnn.init。本节课我们先重点关注 nn.Module

nn.Module

nn.Module 中有 8 个重要的属性,用于管理整个模型:

1
2
3
4
5
6
7
8
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()

主要属性

  • parameters:存储管理 nn.Parameter 类。
  • modules:存储管理 nn.Module 类。
  • buffers:存储管理缓冲属性,如 BN 层中的 running_mean
  • ***_hooks:存储管理钩子函数。

这里,我们重点关注其中的两个属性:parametersmodules

nn.Module 的属性构建机制

module 类里面进行属性赋值时会先被 __setattr__ 函数拦截,该函数对即将赋值的数据类型进行判断:如果赋值是 nn.Parameter 类,则将其存入 parameters 字典中进行管理;如果赋值是 nn.Module 类,则将其存入 modules 字典中进行管理。

nn.Module 总结

  • 一个 module 可以包含多个子 module。
    • 例如:在 LeNet 这个 module 下会包含一些卷积层、池化层等子 module。
  • 一个 module 相当于一个运算,必须实现 forward() 函数。
    • 从计算图的角度来看,一个 module 接收一个张量,经过一系列复杂运算,输出概率或者其他数据。因此,我们需要在其中实现一个前向传播的函数。
  • 每个 module 都有 8 个 有序字典 (OrderedDict) 管理它的属性。
    • 这里,最常用的是 parameters 字典和 modules 字典。

3. 总结

本节课中,我们学习了 nn.Module 的概念以及模型创建的两个要素。下节课中,我们将学习容器 Containers 以及 AlexNet 的搭建。

下节内容:模型容器与 AlexNet 构建

知识共享许可协议本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。 欢迎转载,并请注明来自:YEY 的博客 同时保持文章内容的完整和以上声明信息!