深度学习入门之Pytorch 数据增强的实现


Posted in Python onFebruary 26, 2020

数据增强

卷积神经网络非常容易出现过拟合的问题,而数据增强的方法是对抗过拟合问题的一个重要方法。

2012 年 AlexNet 在 ImageNet 上大获全胜,图片增强方法功不可没,因为有了图片增强,使得训练的数据集比实际数据集多了很多'新'样本,减少了过拟合的问题,下面我们来具体解释一下。

常用的数据增强方法

常用的数据增强方法如下:
1.对图片进行一定比例缩放
2.对图片进行随机位置的截取
3.对图片进行随机的水平和竖直翻转
4.对图片进行随机角度的旋转
5.对图片进行亮度、对比度和颜色的随机变化

这些方法 pytorch 都已经为我们内置在了 torchvision 里面,我们在安装 pytorch 的时候也安装了 torchvision,下面我们来依次展示一下这些数据增强方法。

import sys
sys.path.append('..')

from PIL import Image
from torchvision import transforms as tfs

# 读入一张图片
im = Image.open('./cat.png')
im

深度学习入门之Pytorch 数据增强的实现

随机比例放缩

随机比例缩放主要使用的是 torchvision.transforms.Resize() 这个函数,第一个参数可以是一个整数,那么图片会保存现在的宽和高的比例,并将更短的边缩放到这个整数的大小,第一个参数也可以是一个 tuple,那么图片会直接把宽和高缩放到这个大小;第二个参数表示放缩图片使用的方法,比如最邻近法,或者双线性差值等,一般双线性差值能够保留图片更多的信息,所以 pytorch 默认使用的是双线性差值,你可以手动去改这个参数,更多的信息可以看看文档

# 比例缩放
print('before scale, shape: {}'.format(im.size))
new_im = tfs.Resize((100, 200))(im)
print('after scale, shape: {}'.format(new_im.size))
new_im

深度学习入门之Pytorch 数据增强的实现

随机位置截取

随机位置截取能够提取出图片中局部的信息,使得网络接受的输入具有多尺度的特征,所以能够有较好的效果。在 torchvision 中主要有下面两种方式,一个是 torchvision.transforms.RandomCrop(),传入的参数就是截取出的图片的长和宽,对图片在随机位置进行截取;第二个是 torchvision.transforms.CenterCrop(),同样传入介曲初的图片的大小作为参数,会在图片的中心进行截取

# 随机裁剪出 100 x 100 的区域
random_im1 = tfs.RandomCrop(100)(im)
random_im1

深度学习入门之Pytorch 数据增强的实现

# 中心裁剪出 100 x 100 的区域
center_im = tfs.CenterCrop(100)(im)
center_im

深度学习入门之Pytorch 数据增强的实现

随机的水平和竖直方向翻转

对于上面这一张猫的图片,如果我们将它翻转一下,它仍然是一张猫,但是图片就有了更多的多样性,所以随机翻转也是一种非常有效的手段。在 torchvision 中,随机翻转使用的是 torchvision.transforms.RandomHorizontalFlip()torchvision.transforms.RandomVerticalFlip()

# 随机水平翻转
h_filp = tfs.RandomHorizontalFlip()(im)
h_filp

深度学习入门之Pytorch 数据增强的实现

# 随机竖直翻转
v_flip = tfs.RandomVerticalFlip()(im)
v_flip

深度学习入门之Pytorch 数据增强的实现

随机角度旋转

一些角度的旋转仍然是非常有用的数据增强方式,在 torchvision 中,使用 torchvision.transforms.RandomRotation() 来实现,其中第一个参数就是随机旋转的角度,比如填入 10,那么每次图片就会在 -10 ~ 10 度之间随机旋转

rot_im = tfs.RandomRotation(45)(im)
rot_im

深度学习入门之Pytorch 数据增强的实现

亮度、对比度和颜色的变化

除了形状变化外,颜色变化又是另外一种增强方式,其中可以设置亮度变化,对比度变化和颜色变化等,在 torchvision 中主要使用 torchvision.transforms.ColorJitter() 来实现的,第一个参数就是亮度的比例,第二个是对比度,第三个是饱和度,第四个是颜色

# 亮度
bright_im = tfs.ColorJitter(brightness=1)(im) # 随机从 0 ~ 2 之间亮度变化,1 表示原图
bright_im

深度学习入门之Pytorch 数据增强的实现

# 对比度
contrast_im = tfs.ColorJitter(contrast=1)(im) # 随机从 0 ~ 2 之间对比度变化,1 表示原图
contrast_im

深度学习入门之Pytorch 数据增强的实现

# 颜色
color_im = tfs.ColorJitter(hue=0.5)(im) # 随机从 -0.5 ~ 0.5 之间对颜色变化
color_im

深度学习入门之Pytorch 数据增强的实现

上面我们讲了这么图片增强的方法,其实这些方法都不是孤立起来用的,可以联合起来用,比如先做随机翻转,然后随机截取,再做对比度增强等等,torchvision 里面有个非常方便的函数能够将这些变化合起来,就是 torchvision.transforms.Compose(),下面我们举个例子

im_aug = tfs.Compose([
  tfs.Resize(120),
  tfs.RandomHorizontalFlip(),
  tfs.RandomCrop(96),
  tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5)
])
import matplotlib.pyplot as plt
%matplotlib inline
nrows = 3
ncols = 3
figsize = (8, 8)
_, figs = plt.subplots(nrows, ncols, figsize=figsize)
for i in range(nrows):
  for j in range(ncols):
    figs[i][j].imshow(im_aug(im))
    figs[i][j].axes.get_xaxis().set_visible(False)
    figs[i][j].axes.get_yaxis().set_visible(False)
plt.show()

深度学习入门之Pytorch 数据增强的实现

可以看到每次做完增强之后的图片都有一些变化,所以这就是我们前面讲的,增加了一些'新'数据
下面我们使用图像增强进行训练网络,看看具体的提升究竟在什么地方,使用 ResNet 进行训练

使用数据增强

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from utils import train, resnet
from torchvision import transforms as tfs
# 使用数据增强
def train_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(120),
    tfs.RandomHorizontalFlip(),
    tfs.RandomCrop(96),
    tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

def test_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(96),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

train_set = CIFAR10('./data', train=True, transform=train_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=test_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
train(net, train_data, test_data, 10, optimizer, criterion)

深度学习入门之Pytorch 数据增强的实现

不使用数据增强

# 不使用数据增强
def data_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(96),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

train_set = CIFAR10('./data', train=True, transform=data_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=data_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
train(net, train_data, test_data, 10, optimizer, criterion)

深度学习入门之Pytorch 数据增强的实现

从上面可以看出,对于训练集,不做数据增强跑 10 次,准确率已经到了 95%,而使用了数据增强,跑 10 次准确率只有 75%,说明数据增强之后变得更难了。

而对于测试集,使用数据增强进行训练的时候,准确率会比不使用更高,因为数据增强提高了模型应对于更多的不同数据集的泛化能力,所以有更好的效果。

以上就是深度学习入门之Pytorch 数据增强的实现的详细内容,更多关于Pytorch 数据增强的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python脚本在Appium库上对移动应用实现自动化测试
Apr 17 Python
python在指定目录下查找gif文件的方法
May 04 Python
Python实现简单字典树的方法
Apr 29 Python
Python使用django框架实现多人在线匿名聊天的小程序
Nov 29 Python
Jupyter安装nbextensions,启动提示没有nbextensions库
Apr 23 Python
Python实现通过解析域名获取ip地址的方法分析
May 17 Python
Python3 解决读取中文文件txt编码的问题
Dec 20 Python
浅谈Python中range与Numpy中arange的比较
Mar 11 Python
Django多数据库配置及逆向生成model教程
Mar 28 Python
在python中利用pycharm自定义代码块教程(三步搞定)
Apr 15 Python
Python爬虫如何破解JS加密的Cookie
Nov 19 Python
Github 使用python对copilot做些简单使用测试
Apr 14 Python
Python基于Dlib的人脸识别系统的实现
Feb 26 #Python
python 回溯法模板详解
Feb 26 #Python
python实现信号时域统计特征提取代码
Feb 26 #Python
Python 基于FIR实现Hilbert滤波器求信号包络详解
Feb 26 #Python
python实现逆滤波与维纳滤波示例
Feb 26 #Python
Python全面分析系统的时域特性和频率域特性
Feb 26 #Python
解决pycharm每次打开项目都需要配置解释器和安装库问题
Feb 26 #Python
You might like
PHP中实现生成静态文件的方法缓解服务器压力
2014/01/07 PHP
php实现兼容2038年后Unix时间戳转换函数
2015/03/18 PHP
PHP表单提交后引号前自动加反斜杠的原因及三种办法关闭php魔术引号
2015/09/30 PHP
javascript replace方法与正则表达式
2008/02/19 Javascript
入门基础学习 ExtJS笔记(一)
2010/11/11 Javascript
javascript验证只能输入数字和一个小数点示例
2013/10/21 Javascript
JS中三目运算符和if else的区别分析与示例
2014/11/21 Javascript
纯js实现重发验证码按钮倒数功能
2015/04/21 Javascript
js实现根据身份证号自动生成出生日期
2015/12/15 Javascript
jQuery插件FusionCharts实现的3D柱状图效果实例【附demo源码下载】
2017/03/03 Javascript
微信小程序实现滑动删除效果
2017/05/19 Javascript
Node.js+jade抓取博客所有文章生成静态html文件的实例
2017/09/19 Javascript
js数组常用最重要的方法
2018/02/04 Javascript
详解js的作用域、预解析机制
2018/02/05 Javascript
在vue中使用axios实现post方式获取二进制流下载文件(实例代码)
2019/12/16 Javascript
vue 微信扫码登录(自定义样式)
2020/01/06 Javascript
es6中new.target的作用和使用场景简单示例分析
2020/03/14 Javascript
vue+element实现动态加载表单
2020/12/13 Vue.js
在Python的web框架中中编写日志列表的教程
2015/04/30 Python
Python2.x版本中基本的中文编码问题解决
2015/10/12 Python
wxpython中Textctrl回车事件无效的解决方法
2016/07/21 Python
matlab中二维插值函数interp2的使用详解
2020/04/22 Python
浅谈Keras的Sequential与PyTorch的Sequential的区别
2020/06/17 Python
法国发饰品牌:Alexandre De Paris
2018/12/04 全球购物
英国泽西岛植物:Jersey Plants Direct
2019/08/07 全球购物
同步和异步有何异同,在什么情况下分别使用他们?
2012/12/28 面试题
档案接收函
2014/01/13 职场文书
中学劳技课教师的自我评价
2014/02/05 职场文书
奥巴马当选演讲稿
2014/09/10 职场文书
合伙经营协议书范本
2014/09/13 职场文书
高考作弊检讨书1500字
2015/02/16 职场文书
火烧圆明园观后感
2015/06/03 职场文书
2016年师德师风学习心得体会
2016/01/12 职场文书
2019入党申请书范文3篇
2019/08/21 职场文书
JS Canvas接口和动画效果大全
2021/04/29 Javascript
Oracle 触发器trigger使用案例
2022/02/24 Oracle