Lecture 03 张量操作与线性回归
1. 张量操作
1.1 张量拼接与切分
torch.cat()
功能:将张量按维度 dim
进行拼接。
1
2
3
4
5
torch.cat(
tensors,
dim=0,
out=None
)
主要参数:
tensors
:张量序列。dim
:要拼接的维度。
torch.stack()
功能:在 新创建的维度 dim
上进行拼接。
1
2
3
4
5
torch.stack(
tensors,
dim=0,
out=None
)
主要参数:
tensors
:张量序列。dim
:要拼接的维度。
torch.chunk()
功能:将张量按维度 dim
进行 平均切分。
1
2
3
4
5
torch.chunk(
input,
chunks,
dim=0
)
主要参数:
input
:要切分的张量。chunks
:要切分的份数。dim
:要切分的维度。
返回值:张量列表。
注意事项:若不能整除,最后一份张量将小于其他张量。
torch.split()
功能:将张量按维度 dim
进行切分。
1
2
3
4
5
torch.split(
tensor,
split_size_or_sections,
dim=0
)
主要参数:
tensor
:要切分的张量。split_size_or_sections
:为int
时,表示每一份的长度;为list
时,按list
元素切分。dim
:要切分的维度。
返回值:张量列表。
1.2 张量索引
torch.index_select()
功能:在维度 dim
上,按 index
索引数据。
1
2
3
4
5
6
torch.index_select(
input,
dim,
index,
out=None
)
主要参数:
input
:要索引的张量。dim
:要索引的维度。index
:要索引数据的序号,注意这里index
中的数据类型必须是torch.long
。
返回值:依 index
索引数据拼接的张量。
torch.masked_select()
功能:按 mask
中的 True
进行索引。
1
2
3
4
5
torch.masked_select(
input,
mask,
out=None
)
主要参数:
input
:要索引的张量。mask
:与input
同形状的布尔类型张量。
返回值:一维 张量。
1.3 张量变换
torch.reshape()
功能:变换张量形状。
1
2
3
4
torch.reshape(
input,
shape
)
主要参数:
input
:要变换的张量。shape
:新张量的形状,当我们不需要关心某个维度时,可以将其设为 -1,它将通过对其他维度的计算自动得出。
注意事项:当张量在内存中是连续时,新张量与 input
共享数据内存。
torch.transpose()
功能:交换张量的两个维度,常用于图像的预处理。
1
2
3
4
5
torch.transpose(
input,
dim0,
dim1
)
主要参数:
input
:要变换的张量。dim0
:要交换的维度。dim1
:要交换的维度。
torch.t()
功能:2 维张量转置,对矩阵而言,等价于 torch.transpose(input, 0, 1)
。
1
torch.t(input)
torch.squeeze()
功能:压缩 长度为 1 的维度(轴)。
1
2
3
4
5
torch.squeeze(
input,
dim=None,
out=None
)
主要参数:
dim
:若为None
,移除所有长度为 1 的轴;若指定维度,当且仅当该轴长度为 1 时,可以被移除。
torch.unsqueeze()
功能:依据 dim
扩展 维度。
1
2
3
4
5
torch.usqueeze(
input,
dim,
out=None
)
主要参数:
dim
:扩展的维度。
2. 张量数学运算
2.1 加减乘除
1
2
3
4
5
6
torch.add()
torch.addcdiv()
torch.addcmul()
torch.sub()
torch.div()
torch.mul()
torch.add()
功能:逐元素计算 input + alpha × other
1
2
3
4
5
6
torch.add(
input,
alpha=1,
other,
out=None
)
主要参数:
input
:第一个张量。alpha
:乘项因子。other
:第二个张量。
torch.addcmul()
功能:逐元素计算
\[\texttt{out}_i = \texttt{input}_i + \texttt{value} \times \texttt{tensor1}_i \times \texttt{tensor2}_i\]1
2
3
4
5
6
7
torch.addcmul(
input,
value=1,
tensor1,
tensor2,
out=None
)
torch.addcdiv()
功能:逐元素计算
\[\texttt{out}_i = \texttt{input}_i + \texttt{value} \times \dfrac{\texttt{tensor1}_i}{\texttt{tensor2}_i}\]1
2
3
4
5
6
7
torch.addcdiv(
input,
value=1,
tensor1,
tensor2,
out=None
)
2.2 对数、指数、幂函数
1
2
3
4
5
torch.log(input, out=None)
torch.log10(input, out=None)
torch.log2(input, out=None)
torch.exp(input, out=None)
torch.pow()
2.3 三角函数
1
2
3
4
5
6
7
torch.abs(input, out=None)
torch.acos(input, out=None)
torch.cosh(input, out=None)
torch.cos(input, out=None)
torch.asin(input, out=None)
torch.atan(input, out=None)
torch.atan2(input, other, out=None)
3. 线性回归
线性回归 是分析一个变量与另外一(多)个变量之间关系的方法:
\[y=wx+b\]其中,$y$ 是 因变量,$x$ 是 自变量,二者之间关系为 线性。
分析:求解线性组合系数 $w$ 和 $b$。
求解步骤:
-
确定模型
\[\text{Model:}\quad y=wx+b\] -
选择损失函数
\[\text{MSE:}\quad \dfrac{1}{n}\sum_{i=1}^{n}(y_i - \hat y_i)^2\] -
求解梯度并更新 $w$ 和 $b$
\[\begin{align} w &= w - \mathrm{LR} * w.\texttt{grad} \\[2ex] b &= b - \mathrm{LR} * w.\texttt{grad} \end{align}\]其中,$\mathrm{LR}$ 是 学习率 (learning rate)。
代码示例:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
import matplotlib.pyplot as plt
torch.manual_seed(10)
# 学习率
lr = 0.1
# 创建训练数据
x = torch.rand(20, 1) * 10 # x data (tensor), shape=(20, 1)
y = 2*x + (5 + torch.randn(20, 1)) # y data (tensor), shape=(20, 1)
# 初始化线性回归参数
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)
for iteration in range(1000):
# 向前传播
wx = torch.mul(w, x)
y_pred = torch.add(wx, b)
# 计算 MSE loss
loss = (0.5 * (y - y_pred) ** 2).mean()
# 反向传播
loss.backward()
# 更新参数
b.data.sub_(lr * b.grad)
w.data.sub_(lr * w.grad)
# 绘图
if iteration % 20 == 0:
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), y_pred.data.numpy(), 'r-', lw=5)
plt.text(2, 20, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.xlim(1.5, 10)
plt.ylim(8, 28)
plt.title("Iteration: {}\nw: {} b: {}".format(iteration, w.data.numpy(), b.data.numpy()))
plt.pause(0.5)
# 当 loss < 1 时,停止迭代更新
if loss.data.numpy() < 1:
break
4. 总结
本节课介绍了张量的基本操作,例如:张量的拼接、切分、索引和变换。同时,我们还学习了张量的数学运算,并基于所学习的知识,实现线性回归模型的训练,以加深知识点的认识。
下节内容:计算图与动态图机制
本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。 欢迎转载,并请注明来自:YEY 的博客 同时保持文章内容的完整和以上声明信息!