PyTorch 08:transforms 数据增强:裁剪、翻转、旋转

PyTorch 学习笔记

Posted by YEY on December 12, 2020

Lecture 08 transforms 数据增强:裁剪、翻转、旋转

在之前课程中,我们已经熟悉了 PyTorch 中 transforms 的运行机制,它提供了大量的图像增强方法,例如裁剪、旋转、翻转等等,以及可以自定义实现增强方法。本节课中,我们将进一步学习 transforms 中的图像增强方法。

1. 数据增强

数据增强 (Data Augmentation) 又称为数据增广、数据扩增,它是对 训练集 进行变换,使训练集更丰富,从而让模型更具 泛化能力

例子

例子

2. transforms 裁剪

transforms.CenterCrop

功能:从图像中心裁剪图片。

1
transforms.CenterCrop(size)

主要参数

  • size:所需裁剪图片尺寸。

代码示例

我们有一个 224 $\times$ 224 的图片,我们将其从中心裁剪为 196 $\times$ 196 的图片。

1
2
3
4
5
6
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),

    # CenterCrop,如果 size 大于原始尺寸,多余部分将用黑色 (即像素值为 0) 填充
    transforms.CenterCrop(196)
])

transforms.RandomCrop

功能:从图片中随机裁剪出尺寸为 size 的图片。

1
2
3
4
5
6
7
transforms.RandomCrop(
    size,
    padding=None,
    pad_if_needed=False,
    fill=0,
    padding_mode='constant'
)

主要参数

  • size:所需裁剪图片尺寸。
  • padding:设置填充大小。
    • 当为 a 时,上下左右均填充 a 个像素。
    • 当为 (a, b) 时, 上下填充 b 个像素, 左右填充 a 个像素。
    • 当为 (a, b, c, d) 时,左、上、右、下分别填充 abcd 个像素。
  • pad_if_need:若图像小于设定 size,则填充,此时该项需要设置为 True
  • padding_mode:填充模式,有 4 种模式:
    • constant:像素值由 fill 设定。
    • edge:像素值由图像边缘像素决定。
    • reflect:镜像填充,最后一个像素不镜像,例如 [1, 2, 3, 4] $\to$ [3, 2, 1, 2, 3, 4, 3, 2]
    • symmetric:镜像填充,最后一个像素镜像,例如 [1, 2, 3, 4] $\to$ [2, 1, 1, 2, 3, 4, 4, 3]
  • fillpadding_mode = 'constant' 时,设置填充的像素值。

transforms.RandomResizedCrop

功能:随机大小、长宽比裁剪图片。

1
2
3
4
5
6
RandomResizedCrop(
    size,
    scale=(0.08, 1.0),
    ratio=(3/4, 4/3),
    interpolation
)

主要参数

  • size:所需裁剪图片尺寸。
  • scale:随机裁剪面积比例,默认 (0.08, 1)
  • ratio:随机长宽比,默认 (3/4, 4/3)
  • interpolation:插值方法。
    • PIL.Image.NEAREST
    • PIL.Image.BILINEAR
    • PIL.Image.BICUBIC

transforms.FiveCrop

功能:在图像的上下左右以及中心裁剪出尺寸为 size 的 5 张图片。

1
transforms.FiveCrop(size)

主要参数

  • size:所需裁剪图片尺寸。

代码示例

1
2
3
4
5
6
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    # 注意:由于生成了 5 张图片,返回的是一个元组,我们需要将其转换为 PIL Image 或者 ndarray 的形式。
    transforms.FiveCrop(112),
    transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops]))
])

transforms.TenCrop

功能:在图像的上下左右以及中心裁剪出尺寸为 size 的 5 张图片,并对这 5 张图片进行水平或者垂直镜像获得 10 张图片。

1
2
3
4
transforms.TenCrop(
    size,
    vertical_flip=False
)

主要参数

  • size:所需裁剪图片尺寸。
  • vertical_flip:是否垂直翻转。

代码示例

1
2
3
4
5
6
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    # 注意:由于生成了 10 张图片,返回的是一个元组,我们需要将其转换为 PIL Image 或者 ndarray 的形式。
    transforms.TenCrop(112, vertical_flip=False),
    transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])),
])

3. transforms 翻转、旋转

3.1 transforms 翻转

transforms.RandomHorizontalFlip

功能:依概率水平(左右)翻转图片。

1
transforms.RandomHorizontalFlip(p=0.5)

主要参数

  • p:翻转概率。

transforms.RandomVerticalFlip

功能:依概率垂直(上下)翻转图片。

1
transforms.RandomVerticalFlip(p=0.5)

主要参数

  • p:翻转概率。

3.2 transforms 旋转

transforms.RandomRotation

功能:随机旋转图片。

1
2
3
4
5
6
RandomRotation(
    degrees,
    resample=False,
    expand=False,
    center=None
)

主要参数

  • degrees:旋转角度。
    • 当为 a 时,在 (-a, a) 之间随机选择旋转角度。
    • 当为 (a, b) 时,在 (a, b) 之间随机选择旋转角度。
  • resample:重采样方法。
  • expand:是否扩大图片,以保持原图信息。
  • center:旋转点设置,默认中心旋转。

例子

4. 总结

本节课中,我们学习了数据预处理模块 transforms 中的数据增强方法:裁剪、翻转和旋转。在下次课程中 ,我们将会学习 transforms 中的其他数据增强方法。

下节内容:transforms 图像变换、方法操作及自定义方法

知识共享许可协议本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。 欢迎转载,并请注明来自:YEY 的博客 同时保持文章内容的完整和以上声明信息!