PyTorch 03:张量操作与线性回归

PyTorch 学习笔记

Posted by YEY on December 7, 2020

Lecture 03 张量操作与线性回归

1. 张量操作

1.1 张量拼接与切分

torch.cat()

功能:将张量按维度 dim 进行拼接。

1
2
3
4
5
torch.cat(
    tensors,
    dim=0,
    out=None
)

主要参数

  • tensors:张量序列。
  • dim:要拼接的维度。

torch.stack()

功能:在 新创建的维度 dim 上进行拼接。

1
2
3
4
5
torch.stack(
    tensors,
    dim=0,
    out=None
)

主要参数

  • tensors:张量序列。
  • dim:要拼接的维度。

torch.chunk()

功能:将张量按维度 dim 进行 平均切分

1
2
3
4
5
torch.chunk(
    input,
    chunks,
    dim=0
)

主要参数

  • input:要切分的张量。
  • chunks:要切分的份数。
  • dim:要切分的维度。

返回值:张量列表。

注意事项:若不能整除,最后一份张量将小于其他张量。

torch.split()

功能:将张量按维度 dim 进行切分。

1
2
3
4
5
torch.split(
    tensor,
    split_size_or_sections,
    dim=0
)

主要参数

  • tensor:要切分的张量。
  • split_size_or_sections:为 int 时,表示每一份的长度;为 list 时,按 list 元素切分。
  • dim:要切分的维度。

返回值:张量列表。

1.2 张量索引

torch.index_select()

功能:在维度 dim 上,按 index 索引数据。

1
2
3
4
5
6
torch.index_select(
    input,
    dim,
    index,
    out=None
)

主要参数

  • input:要索引的张量。
  • dim:要索引的维度。
  • index:要索引数据的序号,注意这里 index 中的数据类型必须是 torch.long

返回值:依 index 索引数据拼接的张量。

torch.masked_select()

功能:按 mask 中的 True 进行索引。

1
2
3
4
5
torch.masked_select(
    input,
    mask,
    out=None
)

主要参数

  • input:要索引的张量。
  • mask:与 input 同形状的布尔类型张量。

返回值一维 张量。

1.3 张量变换

torch.reshape()

功能:变换张量形状。

1
2
3
4
torch.reshape(
    input,
    shape
)

主要参数

  • input:要变换的张量。
  • shape:新张量的形状,当我们不需要关心某个维度时,可以将其设为 -1,它将通过对其他维度的计算自动得出。

注意事项:当张量在内存中是连续时,新张量与 input 共享数据内存。

torch.transpose()

功能:交换张量的两个维度,常用于图像的预处理。

1
2
3
4
5
torch.transpose(
    input,
    dim0,
    dim1
)

主要参数

  • input:要变换的张量。
  • dim0:要交换的维度。
  • dim1:要交换的维度。

torch.t()

功能:2 维张量转置,对矩阵而言,等价于 torch.transpose(input, 0, 1)

1
torch.t(input)

torch.squeeze()

功能压缩 长度为 1 的维度(轴)。

1
2
3
4
5
torch.squeeze(
    input,
    dim=None,
    out=None
)

主要参数

  • dim:若为 None,移除所有长度为 1 的轴;若指定维度,当且仅当该轴长度为 1 时,可以被移除。

torch.unsqueeze()

功能:依据 dim 扩展 维度。

1
2
3
4
5
torch.usqueeze(
    input,
    dim,
    out=None
)

主要参数

  • dim:扩展的维度。

2. 张量数学运算

2.1 加减乘除

1
2
3
4
5
6
torch.add()
torch.addcdiv()
torch.addcmul()
torch.sub()
torch.div()
torch.mul()

torch.add()

功能:逐元素计算 input + alpha × other

1
2
3
4
5
6
torch.add(
    input,
    alpha=1,
    other,
    out=None
)

主要参数

  • input:第一个张量。
  • alpha:乘项因子。
  • other:第二个张量。

torch.addcmul()

功能:逐元素计算

\[\texttt{out}_i = \texttt{input}_i + \texttt{value} \times \texttt{tensor1}_i \times \texttt{tensor2}_i\]
1
2
3
4
5
6
7
torch.addcmul(
    input,
    value=1,
    tensor1,
    tensor2,
    out=None
)

torch.addcdiv()

功能:逐元素计算

\[\texttt{out}_i = \texttt{input}_i + \texttt{value} \times \dfrac{\texttt{tensor1}_i}{\texttt{tensor2}_i}\]
1
2
3
4
5
6
7
torch.addcdiv(
    input,
    value=1,
    tensor1,
    tensor2,
    out=None
)

2.2 对数、指数、幂函数

1
2
3
4
5
torch.log(input, out=None)
torch.log10(input, out=None)
torch.log2(input, out=None)
torch.exp(input, out=None)
torch.pow()

2.3 三角函数

1
2
3
4
5
6
7
torch.abs(input, out=None)
torch.acos(input, out=None)
torch.cosh(input, out=None)
torch.cos(input, out=None)
torch.asin(input, out=None)
torch.atan(input, out=None)
torch.atan2(input, other, out=None)

3. 线性回归

线性回归 是分析一个变量与另外一(多)个变量之间关系的方法:

\[y=wx+b\]

其中,$y$ 是 因变量,$x$ 是 自变量,二者之间关系为 线性

分析:求解线性组合系数 $w$ 和 $b$。

求解步骤

  1. 确定模型

    \[\text{Model:}\quad y=wx+b\]
  2. 选择损失函数

    \[\text{MSE:}\quad \dfrac{1}{n}\sum_{i=1}^{n}(y_i - \hat y_i)^2\]
  3. 求解梯度并更新 $w$ 和 $b$

    \[\begin{align} w &= w - \mathrm{LR} * w.\texttt{grad} \\[2ex] b &= b - \mathrm{LR} * w.\texttt{grad} \end{align}\]

    其中,$\mathrm{LR}$ 是 学习率 (learning rate)

代码示例


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
import matplotlib.pyplot as plt

torch.manual_seed(10)

# 学习率
lr = 0.1

# 创建训练数据
x = torch.rand(20, 1) * 10  # x data (tensor), shape=(20, 1)
y = 2*x + (5 + torch.randn(20, 1))  # y data (tensor), shape=(20, 1)

# 初始化线性回归参数
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)

for iteration in range(1000):

    # 向前传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()

    # 反向传播
    loss.backward()

    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

    # 绘图
    if iteration % 20 == 0:

        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), y_pred.data.numpy(), 'r-', lw=5)
        plt.text(2, 20, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
        plt.xlim(1.5, 10)
        plt.ylim(8, 28)
        plt.title("Iteration: {}\nw: {} b: {}".format(iteration, w.data.numpy(), b.data.numpy()))
        plt.pause(0.5)

        # 当 loss < 1 时,停止迭代更新
        if loss.data.numpy() < 1:
            break

4. 总结

本节课介绍了张量的基本操作,例如:张量的拼接、切分、索引和变换。同时,我们还学习了张量的数学运算,并基于所学习的知识,实现线性回归模型的训练,以加深知识点的认识。

下节内容:计算图与动态图机制

知识共享许可协议本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。 欢迎转载,并请注明来自:YEY 的博客 同时保持文章内容的完整和以上声明信息!