深度学习入门之Pytorch 数据增强的实现


Posted in Python onFebruary 26, 2020

数据增强

卷积神经网络非常容易出现过拟合的问题,而数据增强的方法是对抗过拟合问题的一个重要方法。

2012 年 AlexNet 在 ImageNet 上大获全胜,图片增强方法功不可没,因为有了图片增强,使得训练的数据集比实际数据集多了很多'新'样本,减少了过拟合的问题,下面我们来具体解释一下。

常用的数据增强方法

常用的数据增强方法如下:
1.对图片进行一定比例缩放
2.对图片进行随机位置的截取
3.对图片进行随机的水平和竖直翻转
4.对图片进行随机角度的旋转
5.对图片进行亮度、对比度和颜色的随机变化

这些方法 pytorch 都已经为我们内置在了 torchvision 里面,我们在安装 pytorch 的时候也安装了 torchvision,下面我们来依次展示一下这些数据增强方法。

import sys
sys.path.append('..')

from PIL import Image
from torchvision import transforms as tfs

# 读入一张图片
im = Image.open('./cat.png')
im

深度学习入门之Pytorch 数据增强的实现

随机比例放缩

随机比例缩放主要使用的是 torchvision.transforms.Resize() 这个函数,第一个参数可以是一个整数,那么图片会保存现在的宽和高的比例,并将更短的边缩放到这个整数的大小,第一个参数也可以是一个 tuple,那么图片会直接把宽和高缩放到这个大小;第二个参数表示放缩图片使用的方法,比如最邻近法,或者双线性差值等,一般双线性差值能够保留图片更多的信息,所以 pytorch 默认使用的是双线性差值,你可以手动去改这个参数,更多的信息可以看看文档

# 比例缩放
print('before scale, shape: {}'.format(im.size))
new_im = tfs.Resize((100, 200))(im)
print('after scale, shape: {}'.format(new_im.size))
new_im

深度学习入门之Pytorch 数据增强的实现

随机位置截取

随机位置截取能够提取出图片中局部的信息,使得网络接受的输入具有多尺度的特征,所以能够有较好的效果。在 torchvision 中主要有下面两种方式,一个是 torchvision.transforms.RandomCrop(),传入的参数就是截取出的图片的长和宽,对图片在随机位置进行截取;第二个是 torchvision.transforms.CenterCrop(),同样传入介曲初的图片的大小作为参数,会在图片的中心进行截取

# 随机裁剪出 100 x 100 的区域
random_im1 = tfs.RandomCrop(100)(im)
random_im1

深度学习入门之Pytorch 数据增强的实现

# 中心裁剪出 100 x 100 的区域
center_im = tfs.CenterCrop(100)(im)
center_im

深度学习入门之Pytorch 数据增强的实现

随机的水平和竖直方向翻转

对于上面这一张猫的图片,如果我们将它翻转一下,它仍然是一张猫,但是图片就有了更多的多样性,所以随机翻转也是一种非常有效的手段。在 torchvision 中,随机翻转使用的是 torchvision.transforms.RandomHorizontalFlip()torchvision.transforms.RandomVerticalFlip()

# 随机水平翻转
h_filp = tfs.RandomHorizontalFlip()(im)
h_filp

深度学习入门之Pytorch 数据增强的实现

# 随机竖直翻转
v_flip = tfs.RandomVerticalFlip()(im)
v_flip

深度学习入门之Pytorch 数据增强的实现

随机角度旋转

一些角度的旋转仍然是非常有用的数据增强方式,在 torchvision 中,使用 torchvision.transforms.RandomRotation() 来实现,其中第一个参数就是随机旋转的角度,比如填入 10,那么每次图片就会在 -10 ~ 10 度之间随机旋转

rot_im = tfs.RandomRotation(45)(im)
rot_im

深度学习入门之Pytorch 数据增强的实现

亮度、对比度和颜色的变化

除了形状变化外,颜色变化又是另外一种增强方式,其中可以设置亮度变化,对比度变化和颜色变化等,在 torchvision 中主要使用 torchvision.transforms.ColorJitter() 来实现的,第一个参数就是亮度的比例,第二个是对比度,第三个是饱和度,第四个是颜色

# 亮度
bright_im = tfs.ColorJitter(brightness=1)(im) # 随机从 0 ~ 2 之间亮度变化,1 表示原图
bright_im

深度学习入门之Pytorch 数据增强的实现

# 对比度
contrast_im = tfs.ColorJitter(contrast=1)(im) # 随机从 0 ~ 2 之间对比度变化,1 表示原图
contrast_im

深度学习入门之Pytorch 数据增强的实现

# 颜色
color_im = tfs.ColorJitter(hue=0.5)(im) # 随机从 -0.5 ~ 0.5 之间对颜色变化
color_im

深度学习入门之Pytorch 数据增强的实现

上面我们讲了这么图片增强的方法,其实这些方法都不是孤立起来用的,可以联合起来用,比如先做随机翻转,然后随机截取,再做对比度增强等等,torchvision 里面有个非常方便的函数能够将这些变化合起来,就是 torchvision.transforms.Compose(),下面我们举个例子

im_aug = tfs.Compose([
  tfs.Resize(120),
  tfs.RandomHorizontalFlip(),
  tfs.RandomCrop(96),
  tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5)
])
import matplotlib.pyplot as plt
%matplotlib inline
nrows = 3
ncols = 3
figsize = (8, 8)
_, figs = plt.subplots(nrows, ncols, figsize=figsize)
for i in range(nrows):
  for j in range(ncols):
    figs[i][j].imshow(im_aug(im))
    figs[i][j].axes.get_xaxis().set_visible(False)
    figs[i][j].axes.get_yaxis().set_visible(False)
plt.show()

深度学习入门之Pytorch 数据增强的实现

可以看到每次做完增强之后的图片都有一些变化,所以这就是我们前面讲的,增加了一些'新'数据
下面我们使用图像增强进行训练网络,看看具体的提升究竟在什么地方,使用 ResNet 进行训练

使用数据增强

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from utils import train, resnet
from torchvision import transforms as tfs
# 使用数据增强
def train_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(120),
    tfs.RandomHorizontalFlip(),
    tfs.RandomCrop(96),
    tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

def test_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(96),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

train_set = CIFAR10('./data', train=True, transform=train_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=test_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
train(net, train_data, test_data, 10, optimizer, criterion)

深度学习入门之Pytorch 数据增强的实现

不使用数据增强

# 不使用数据增强
def data_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(96),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

train_set = CIFAR10('./data', train=True, transform=data_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=data_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
train(net, train_data, test_data, 10, optimizer, criterion)

深度学习入门之Pytorch 数据增强的实现

从上面可以看出,对于训练集,不做数据增强跑 10 次,准确率已经到了 95%,而使用了数据增强,跑 10 次准确率只有 75%,说明数据增强之后变得更难了。

而对于测试集,使用数据增强进行训练的时候,准确率会比不使用更高,因为数据增强提高了模型应对于更多的不同数据集的泛化能力,所以有更好的效果。

以上就是深度学习入门之Pytorch 数据增强的实现的详细内容,更多关于Pytorch 数据增强的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python最长公共子串算法实例
Mar 07 Python
在Python的Flask框架中实现单元测试的教程
Apr 20 Python
Python win32com 操作Exce的l简单方法(必看)
May 25 Python
在python win系统下 打开TXT文件的实例
Apr 29 Python
Python 带有参数的装饰器实例代码详解
Dec 06 Python
Python面向对象程序设计类变量与成员变量、类方法与成员方法用法分析
Apr 12 Python
Python Django基础二之URL路由系统
Jul 18 Python
使用matplotlib绘制图例标签中带有公式的图
Dec 13 Python
通过Turtle库在Python中绘制一个鼠年福鼠
Feb 03 Python
Python Handler处理器和自定义Opener原理详解
Mar 05 Python
解决python父线程关闭后子线程不关闭问题
Apr 25 Python
python进行二次方程式计算的实例讲解
Dec 06 Python
Python基于Dlib的人脸识别系统的实现
Feb 26 #Python
python 回溯法模板详解
Feb 26 #Python
python实现信号时域统计特征提取代码
Feb 26 #Python
Python 基于FIR实现Hilbert滤波器求信号包络详解
Feb 26 #Python
python实现逆滤波与维纳滤波示例
Feb 26 #Python
Python全面分析系统的时域特性和频率域特性
Feb 26 #Python
解决pycharm每次打开项目都需要配置解释器和安装库问题
Feb 26 #Python
You might like
基于OpenCV的PHP图像人脸识别技术
2009/10/11 PHP
PHP中几个常用的魔术常量
2012/02/23 PHP
PHP中空字符串介绍0、null、empty和false之间的关系
2012/09/25 PHP
分享常见的几种页面静态化的方法
2015/01/08 PHP
PHP ajax 异步执行不等待执行结果的处理方法
2015/05/27 PHP
Laravel框架实现redis集群的方法分析
2017/09/14 PHP
ThinkPHP3.2.3框架邮件发送功能图文实例详解
2019/04/23 PHP
php设计模式之职责链模式实例分析【星际争霸游戏案例】
2020/03/27 PHP
JSON 学习之完全手册 图文
2007/05/29 Javascript
Web开发之JavaScript
2012/03/29 Javascript
js 编码转换 gb2312 和 utf8 互转的2种方法
2013/08/07 Javascript
HTML页面滚动时获取离页面顶部的距离2种实现方法
2013/09/05 Javascript
js快速排序的实现代码
2013/12/08 Javascript
Javascript节点关系实例分析
2015/05/15 Javascript
使用jquery动态加载Js文件和Css文件
2015/10/24 Javascript
Jquery效果大全之制作电脑健康体检得分特效附源码下载
2015/11/02 Javascript
BootstrapTable+KnockoutJS自定义T4模板快速生成增删改查页面
2016/08/01 Javascript
JavaScript使用math.js进行精确计算操作示例
2018/06/19 Javascript
原生JS实现旋转轮播图+文字内容切换效果【附源码】
2018/09/29 Javascript
使用Turtle画正螺旋线的方法
2017/09/22 Python
解决python3中解压zip文件是文件名乱码的问题
2018/03/22 Python
对python 匹配字符串开头和结尾的方法详解
2018/10/27 Python
numpy下的flatten()函数用法详解
2019/05/27 Python
python中enumerate() 与zip()函数的使用比较实例分析
2019/09/03 Python
python中删除某个元素的方法解析
2019/11/05 Python
python GUI库图形界面开发之PyQt5选项卡控件QTabWidget详细使用方法与实例
2020/03/01 Python
在ipython notebook中使用argparse方式
2020/04/20 Python
为什么说python更适合树莓派编程
2020/07/20 Python
雅诗兰黛香港官网:Estee Lauder香港
2017/09/26 全球购物
美国领先的在线邮轮旅游公司:CruiseDirect
2018/06/07 全球购物
银河香水:Galaxy Perfume
2019/03/25 全球购物
大专生自我鉴定范文
2013/10/01 职场文书
大学生军训自我鉴定
2014/02/12 职场文书
工厂门卫岗位职责范本
2014/04/04 职场文书
python 如何将两个实数矩阵合并为一个复数矩阵
2021/05/19 Python
Golang 实现WebSockets
2022/04/24 Golang