深度学习入门之Pytorch 数据增强的实现


Posted in Python onFebruary 26, 2020

数据增强

卷积神经网络非常容易出现过拟合的问题,而数据增强的方法是对抗过拟合问题的一个重要方法。

2012 年 AlexNet 在 ImageNet 上大获全胜,图片增强方法功不可没,因为有了图片增强,使得训练的数据集比实际数据集多了很多'新'样本,减少了过拟合的问题,下面我们来具体解释一下。

常用的数据增强方法

常用的数据增强方法如下:
1.对图片进行一定比例缩放
2.对图片进行随机位置的截取
3.对图片进行随机的水平和竖直翻转
4.对图片进行随机角度的旋转
5.对图片进行亮度、对比度和颜色的随机变化

这些方法 pytorch 都已经为我们内置在了 torchvision 里面,我们在安装 pytorch 的时候也安装了 torchvision,下面我们来依次展示一下这些数据增强方法。

import sys
sys.path.append('..')

from PIL import Image
from torchvision import transforms as tfs

# 读入一张图片
im = Image.open('./cat.png')
im

深度学习入门之Pytorch 数据增强的实现

随机比例放缩

随机比例缩放主要使用的是 torchvision.transforms.Resize() 这个函数,第一个参数可以是一个整数,那么图片会保存现在的宽和高的比例,并将更短的边缩放到这个整数的大小,第一个参数也可以是一个 tuple,那么图片会直接把宽和高缩放到这个大小;第二个参数表示放缩图片使用的方法,比如最邻近法,或者双线性差值等,一般双线性差值能够保留图片更多的信息,所以 pytorch 默认使用的是双线性差值,你可以手动去改这个参数,更多的信息可以看看文档

# 比例缩放
print('before scale, shape: {}'.format(im.size))
new_im = tfs.Resize((100, 200))(im)
print('after scale, shape: {}'.format(new_im.size))
new_im

深度学习入门之Pytorch 数据增强的实现

随机位置截取

随机位置截取能够提取出图片中局部的信息,使得网络接受的输入具有多尺度的特征,所以能够有较好的效果。在 torchvision 中主要有下面两种方式,一个是 torchvision.transforms.RandomCrop(),传入的参数就是截取出的图片的长和宽,对图片在随机位置进行截取;第二个是 torchvision.transforms.CenterCrop(),同样传入介曲初的图片的大小作为参数,会在图片的中心进行截取

# 随机裁剪出 100 x 100 的区域
random_im1 = tfs.RandomCrop(100)(im)
random_im1

深度学习入门之Pytorch 数据增强的实现

# 中心裁剪出 100 x 100 的区域
center_im = tfs.CenterCrop(100)(im)
center_im

深度学习入门之Pytorch 数据增强的实现

随机的水平和竖直方向翻转

对于上面这一张猫的图片,如果我们将它翻转一下,它仍然是一张猫,但是图片就有了更多的多样性,所以随机翻转也是一种非常有效的手段。在 torchvision 中,随机翻转使用的是 torchvision.transforms.RandomHorizontalFlip()torchvision.transforms.RandomVerticalFlip()

# 随机水平翻转
h_filp = tfs.RandomHorizontalFlip()(im)
h_filp

深度学习入门之Pytorch 数据增强的实现

# 随机竖直翻转
v_flip = tfs.RandomVerticalFlip()(im)
v_flip

深度学习入门之Pytorch 数据增强的实现

随机角度旋转

一些角度的旋转仍然是非常有用的数据增强方式,在 torchvision 中,使用 torchvision.transforms.RandomRotation() 来实现,其中第一个参数就是随机旋转的角度,比如填入 10,那么每次图片就会在 -10 ~ 10 度之间随机旋转

rot_im = tfs.RandomRotation(45)(im)
rot_im

深度学习入门之Pytorch 数据增强的实现

亮度、对比度和颜色的变化

除了形状变化外,颜色变化又是另外一种增强方式,其中可以设置亮度变化,对比度变化和颜色变化等,在 torchvision 中主要使用 torchvision.transforms.ColorJitter() 来实现的,第一个参数就是亮度的比例,第二个是对比度,第三个是饱和度,第四个是颜色

# 亮度
bright_im = tfs.ColorJitter(brightness=1)(im) # 随机从 0 ~ 2 之间亮度变化,1 表示原图
bright_im

深度学习入门之Pytorch 数据增强的实现

# 对比度
contrast_im = tfs.ColorJitter(contrast=1)(im) # 随机从 0 ~ 2 之间对比度变化,1 表示原图
contrast_im

深度学习入门之Pytorch 数据增强的实现

# 颜色
color_im = tfs.ColorJitter(hue=0.5)(im) # 随机从 -0.5 ~ 0.5 之间对颜色变化
color_im

深度学习入门之Pytorch 数据增强的实现

上面我们讲了这么图片增强的方法,其实这些方法都不是孤立起来用的,可以联合起来用,比如先做随机翻转,然后随机截取,再做对比度增强等等,torchvision 里面有个非常方便的函数能够将这些变化合起来,就是 torchvision.transforms.Compose(),下面我们举个例子

im_aug = tfs.Compose([
  tfs.Resize(120),
  tfs.RandomHorizontalFlip(),
  tfs.RandomCrop(96),
  tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5)
])
import matplotlib.pyplot as plt
%matplotlib inline
nrows = 3
ncols = 3
figsize = (8, 8)
_, figs = plt.subplots(nrows, ncols, figsize=figsize)
for i in range(nrows):
  for j in range(ncols):
    figs[i][j].imshow(im_aug(im))
    figs[i][j].axes.get_xaxis().set_visible(False)
    figs[i][j].axes.get_yaxis().set_visible(False)
plt.show()

深度学习入门之Pytorch 数据增强的实现

可以看到每次做完增强之后的图片都有一些变化,所以这就是我们前面讲的,增加了一些'新'数据
下面我们使用图像增强进行训练网络,看看具体的提升究竟在什么地方,使用 ResNet 进行训练

使用数据增强

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from utils import train, resnet
from torchvision import transforms as tfs
# 使用数据增强
def train_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(120),
    tfs.RandomHorizontalFlip(),
    tfs.RandomCrop(96),
    tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

def test_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(96),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

train_set = CIFAR10('./data', train=True, transform=train_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=test_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
train(net, train_data, test_data, 10, optimizer, criterion)

深度学习入门之Pytorch 数据增强的实现

不使用数据增强

# 不使用数据增强
def data_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(96),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

train_set = CIFAR10('./data', train=True, transform=data_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=data_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
train(net, train_data, test_data, 10, optimizer, criterion)

深度学习入门之Pytorch 数据增强的实现

从上面可以看出,对于训练集,不做数据增强跑 10 次,准确率已经到了 95%,而使用了数据增强,跑 10 次准确率只有 75%,说明数据增强之后变得更难了。

而对于测试集,使用数据增强进行训练的时候,准确率会比不使用更高,因为数据增强提高了模型应对于更多的不同数据集的泛化能力,所以有更好的效果。

以上就是深度学习入门之Pytorch 数据增强的实现的详细内容,更多关于Pytorch 数据增强的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python读取Android permission文件
Nov 01 Python
python实现360皮肤按钮控件示例
Feb 21 Python
Python采用raw_input读取输入值的方法
Aug 18 Python
python多线程编程中的join函数使用心得
Sep 02 Python
对Tensorflow中的变量初始化函数详解
Jul 27 Python
python实现二维插值的三维显示
Dec 17 Python
解决PyCharm控制台输出乱码的问题
Jan 16 Python
Django ModelForm组件使用方法详解
Jul 23 Python
python线程定时器Timer实现原理解析
Nov 30 Python
pandas实现将日期转换成timestamp
Dec 07 Python
django form和field具体方法和属性说明
Jul 09 Python
Scrapy爬虫文件批量运行的实现
Sep 30 Python
Python基于Dlib的人脸识别系统的实现
Feb 26 #Python
python 回溯法模板详解
Feb 26 #Python
python实现信号时域统计特征提取代码
Feb 26 #Python
Python 基于FIR实现Hilbert滤波器求信号包络详解
Feb 26 #Python
python实现逆滤波与维纳滤波示例
Feb 26 #Python
Python全面分析系统的时域特性和频率域特性
Feb 26 #Python
解决pycharm每次打开项目都需要配置解释器和安装库问题
Feb 26 #Python
You might like
php读取javascript设置的cookies的代码
2010/04/12 PHP
PHP中PDO基础教程 入门级
2011/09/04 PHP
php读取mysql乱码,用set names XXX解决的原理分享
2011/12/29 PHP
解决PHP4.0 和 PHP5.0类构造函数的兼容问题
2013/08/01 PHP
php下载excel无法打开的解决方法
2013/12/24 PHP
win7计划任务定时执行PHP脚本设置图解
2014/05/09 PHP
php防止用户重复提交表单
2015/11/02 PHP
php实现简单的上传进度条
2015/11/17 PHP
全新Mac配置PHP开发环境教程
2016/02/03 PHP
php进程间通讯实例分析
2016/07/11 PHP
仿迅雷焦点广告效果(JQuery版)
2008/11/19 Javascript
javascript:history.go()和History.back()的区别及应用
2012/11/25 Javascript
使用js完成节点的增删改复制等的操作
2014/01/02 Javascript
根据当前时间在jsp页面上显示上午或下午
2014/08/18 Javascript
PHP使用方法重载实现动态创建属性的get和set方法
2014/11/17 Javascript
举例详解Python中smtplib模块处理电子邮件的使用
2015/06/24 Javascript
jQuery实现页面评论栏中访客信息自动填写功能的方法
2016/05/23 Javascript
jQuery 跨域访问解决原理案例详解
2016/07/09 Javascript
使用Node.js实现简易MVC框架的方法
2017/08/07 Javascript
JavaScript判断输入是否为数字类型的方法总结
2017/09/28 Javascript
详解webpack之scss和postcss-loader的配置
2018/01/09 Javascript
webstorm添加*.vue文件支持
2018/05/08 Javascript
jQuery实现简单三级联动效果
2020/09/05 jQuery
[03:49]DOTA2 2015国际邀请赛中国区预选赛第二日现场百态
2015/05/27 DOTA
wxpython 学习笔记 第一天
2009/02/09 Python
基于Python的XSS测试工具XSStrike使用方法
2017/07/29 Python
Python下调用Linux的Shell命令的方法
2018/06/12 Python
Python字符串对齐、删除字符串不需要的内容以及格式化打印字符
2021/01/23 Python
HTML5实现应用程序缓存(Application Cache)
2020/06/16 HTML / CSS
Chicco婴儿用品美国官网:汽车座椅、婴儿推车、高脚椅等
2018/11/05 全球购物
Oral-B荷兰:牙医最推荐的品牌
2020/02/25 全球购物
大学生新学期计划书
2014/04/28 职场文书
北大自主招生自荐信
2015/03/04 职场文书
干部培训简讯
2015/07/20 职场文书
导游词之广东佛山(南风古灶)
2019/09/24 职场文书
画错魏国疆域啦!《派对咖孔明》动画因作画失误于官网致歉
2022/04/07 日漫