PyTorch 04:计算图与动态图机制

PyTorch 学习笔记

Posted by YEY on December 8, 2020

Lecture 04 计算图与动态图机制

本节课分为两部分:计算图和 PyTorch 中的动态图机制。在之前的课程中,我们学习了张量的创建和操作,而深度学习就是对张量进行一系列操作,随着操作种类和数量的增多,可能会导致各种意想不到的问题。例如:多个操作之间是并行还是顺序执行;如何协同不同底层设备;如何避免各种冗余操作等等。这些问题都会影响到我们的运算效率,甚至会引入一些不必要的 bug,计算图正是为解决这些问题而生的。

1. 计算图

计算图 (Computational Graph) 是用来 描述运算有向无环图

计算图有两个主要元素:结点 (Node)边 (Edge)

  • 结点表示 数据,例如:向量、矩阵、张量。
  • 边表示 运算,例如:加、减、乘、除、卷积等。

例子

用计算图表示:$y=(x+w) * (w+1)$

我们先将运算过程拆分为:

  • $a = x+w$
  • $b=w+1$
  • $y=a*b$

搭建计算图:

将 $x=2,w=1$ 代入进行计算:

采用计算图来描述运算过程的好处不只是让运算更加简洁,更重要的一点是,它使得梯度求导更加方便。接下来,我们来看一下 $y$ 对 $w$ 求导的过程。

计算图与梯度求导:

\[y=(x+w) * (w+1)\]
  • $a = x+w$
  • $b=w+1$
  • $y=a*b$
\[\begin{align} \dfrac{\partial y}{\partial w} &= \dfrac{\partial y}{\partial a} \dfrac{\partial a}{\partial w} + \dfrac{\partial y}{\partial b} \dfrac{\partial b}{\partial w} \\[2ex] &= b * 1+ a * 1 \\[2ex] &= (w+1) + (x + w) \\[2ex] &= 2*w + x + 1 \\[2ex] &= 2*1 + 2 + 1 = 5 \end{align}\]

在之前课程中,我们提到 Tensor 中有一个 is_leaf 属性,它用于指示张量是否是叶子结点。

叶子结点:用户创建的结点称为叶子结点,例如上面的 $x$ 和 $w$。它是整个计算图的根基,可以看到,在前向传播中的 $a$、$b$ 和 $y$ 都需要根据叶子结点 $x$ 和 $w$ 进行计算。同样,在反向传播中,所有的梯度计算也都依赖于叶子结点。

为什么要设置叶子结点这一概念呢?

主要是为了节省内存,因为在反向传播结束之后,非叶子结点的梯度将被丢弃。

除了叶子结点之外,Tensor 中还有一个重要的概念就是 梯度函数 grad_fn,它记录了创建该张量时所用的方法 (函数),在反向传播时需要用到该方法以 确定对应的求导法则

例如上面的 $a$ 和 $b$ 都是通过加法创建的,而 $y$ 是通过乘法创建的,所以我们有:

  • y.grad_fn = <MulBackward0>
  • a.grad_fn = <AddBackward0>
  • b.grad_fn = <AddBackward0>

Python 代码示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch

# 叶子结点
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

# 非叶子结点
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

# PyTorch 中,为了节省内存,在计算完成后会丢弃非叶子结点的梯度值,即为 None;
# 如果希望保存非叶子结点梯度值,需要在反向传播之前使用 .retain_grad() 方法。
a.retain_grad()
b.retain_grad()
y.retain_grad()

# 反向传播
y.backward()
print(w.grad)

# 查看叶子结点
print("is_leaf:\n", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)

# 查看梯度
print("gradient:\n", w.grad, x.grad, a.grad, b.grad, y.grad)

# 查看 grad_fn
print("grad_fn:\n", w.grad_fn, x.grad_fn, a.grad_fn, b.grad_fn, y.grad_fn)

输出结果:

1
2
3
4
5
6
7
tensor([5.])
is_leaf:
 True True False False False
gradient:
 tensor([5.]) tensor([2.]) tensor([2.]) tensor([3.]) tensor([1.])
grad_fn:
 None None <AddBackward0 object at 0x1146d1518> <AddBackward0 object at 0x1146d1550> <MulBackward0 object at 0x1146d15c0>

2. 动态图

根据计算图搭建方式, 可将计算图分为 动态图 (Dynamic Graphs)静态图 (Static Graphs)。PyTorch 采用的是动态图机制,而 TensorFlow 采用的是静态图机制。

动态图 vs 静态图

3. 总结

本节课介绍了 PyTorch 最大的特性 —— 动态图机制,动态图机制是 PyTorch 与 TensorFlow 最大的区别,我们首先介绍了计算图的概念,并通过演示动态图与静态图的搭建过程来理解动态图与静态图的差异。

下节内容:autograd 与逻辑回归

知识共享许可协议本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。 欢迎转载,并请注明来自:YEY 的博客 同时保持文章内容的完整和以上声明信息!