深度学习入门之Pytorch 数据增强的实现


Posted in Python onFebruary 26, 2020

数据增强

卷积神经网络非常容易出现过拟合的问题,而数据增强的方法是对抗过拟合问题的一个重要方法。

2012 年 AlexNet 在 ImageNet 上大获全胜,图片增强方法功不可没,因为有了图片增强,使得训练的数据集比实际数据集多了很多'新'样本,减少了过拟合的问题,下面我们来具体解释一下。

常用的数据增强方法

常用的数据增强方法如下:
1.对图片进行一定比例缩放
2.对图片进行随机位置的截取
3.对图片进行随机的水平和竖直翻转
4.对图片进行随机角度的旋转
5.对图片进行亮度、对比度和颜色的随机变化

这些方法 pytorch 都已经为我们内置在了 torchvision 里面,我们在安装 pytorch 的时候也安装了 torchvision,下面我们来依次展示一下这些数据增强方法。

import sys
sys.path.append('..')

from PIL import Image
from torchvision import transforms as tfs

# 读入一张图片
im = Image.open('./cat.png')
im

深度学习入门之Pytorch 数据增强的实现

随机比例放缩

随机比例缩放主要使用的是 torchvision.transforms.Resize() 这个函数,第一个参数可以是一个整数,那么图片会保存现在的宽和高的比例,并将更短的边缩放到这个整数的大小,第一个参数也可以是一个 tuple,那么图片会直接把宽和高缩放到这个大小;第二个参数表示放缩图片使用的方法,比如最邻近法,或者双线性差值等,一般双线性差值能够保留图片更多的信息,所以 pytorch 默认使用的是双线性差值,你可以手动去改这个参数,更多的信息可以看看文档

# 比例缩放
print('before scale, shape: {}'.format(im.size))
new_im = tfs.Resize((100, 200))(im)
print('after scale, shape: {}'.format(new_im.size))
new_im

深度学习入门之Pytorch 数据增强的实现

随机位置截取

随机位置截取能够提取出图片中局部的信息,使得网络接受的输入具有多尺度的特征,所以能够有较好的效果。在 torchvision 中主要有下面两种方式,一个是 torchvision.transforms.RandomCrop(),传入的参数就是截取出的图片的长和宽,对图片在随机位置进行截取;第二个是 torchvision.transforms.CenterCrop(),同样传入介曲初的图片的大小作为参数,会在图片的中心进行截取

# 随机裁剪出 100 x 100 的区域
random_im1 = tfs.RandomCrop(100)(im)
random_im1

深度学习入门之Pytorch 数据增强的实现

# 中心裁剪出 100 x 100 的区域
center_im = tfs.CenterCrop(100)(im)
center_im

深度学习入门之Pytorch 数据增强的实现

随机的水平和竖直方向翻转

对于上面这一张猫的图片,如果我们将它翻转一下,它仍然是一张猫,但是图片就有了更多的多样性,所以随机翻转也是一种非常有效的手段。在 torchvision 中,随机翻转使用的是 torchvision.transforms.RandomHorizontalFlip()torchvision.transforms.RandomVerticalFlip()

# 随机水平翻转
h_filp = tfs.RandomHorizontalFlip()(im)
h_filp

深度学习入门之Pytorch 数据增强的实现

# 随机竖直翻转
v_flip = tfs.RandomVerticalFlip()(im)
v_flip

深度学习入门之Pytorch 数据增强的实现

随机角度旋转

一些角度的旋转仍然是非常有用的数据增强方式,在 torchvision 中,使用 torchvision.transforms.RandomRotation() 来实现,其中第一个参数就是随机旋转的角度,比如填入 10,那么每次图片就会在 -10 ~ 10 度之间随机旋转

rot_im = tfs.RandomRotation(45)(im)
rot_im

深度学习入门之Pytorch 数据增强的实现

亮度、对比度和颜色的变化

除了形状变化外,颜色变化又是另外一种增强方式,其中可以设置亮度变化,对比度变化和颜色变化等,在 torchvision 中主要使用 torchvision.transforms.ColorJitter() 来实现的,第一个参数就是亮度的比例,第二个是对比度,第三个是饱和度,第四个是颜色

# 亮度
bright_im = tfs.ColorJitter(brightness=1)(im) # 随机从 0 ~ 2 之间亮度变化,1 表示原图
bright_im

深度学习入门之Pytorch 数据增强的实现

# 对比度
contrast_im = tfs.ColorJitter(contrast=1)(im) # 随机从 0 ~ 2 之间对比度变化,1 表示原图
contrast_im

深度学习入门之Pytorch 数据增强的实现

# 颜色
color_im = tfs.ColorJitter(hue=0.5)(im) # 随机从 -0.5 ~ 0.5 之间对颜色变化
color_im

深度学习入门之Pytorch 数据增强的实现

上面我们讲了这么图片增强的方法,其实这些方法都不是孤立起来用的,可以联合起来用,比如先做随机翻转,然后随机截取,再做对比度增强等等,torchvision 里面有个非常方便的函数能够将这些变化合起来,就是 torchvision.transforms.Compose(),下面我们举个例子

im_aug = tfs.Compose([
  tfs.Resize(120),
  tfs.RandomHorizontalFlip(),
  tfs.RandomCrop(96),
  tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5)
])
import matplotlib.pyplot as plt
%matplotlib inline
nrows = 3
ncols = 3
figsize = (8, 8)
_, figs = plt.subplots(nrows, ncols, figsize=figsize)
for i in range(nrows):
  for j in range(ncols):
    figs[i][j].imshow(im_aug(im))
    figs[i][j].axes.get_xaxis().set_visible(False)
    figs[i][j].axes.get_yaxis().set_visible(False)
plt.show()

深度学习入门之Pytorch 数据增强的实现

可以看到每次做完增强之后的图片都有一些变化,所以这就是我们前面讲的,增加了一些'新'数据
下面我们使用图像增强进行训练网络,看看具体的提升究竟在什么地方,使用 ResNet 进行训练

使用数据增强

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from utils import train, resnet
from torchvision import transforms as tfs
# 使用数据增强
def train_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(120),
    tfs.RandomHorizontalFlip(),
    tfs.RandomCrop(96),
    tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

def test_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(96),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

train_set = CIFAR10('./data', train=True, transform=train_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=test_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
train(net, train_data, test_data, 10, optimizer, criterion)

深度学习入门之Pytorch 数据增强的实现

不使用数据增强

# 不使用数据增强
def data_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(96),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

train_set = CIFAR10('./data', train=True, transform=data_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=data_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
train(net, train_data, test_data, 10, optimizer, criterion)

深度学习入门之Pytorch 数据增强的实现

从上面可以看出,对于训练集,不做数据增强跑 10 次,准确率已经到了 95%,而使用了数据增强,跑 10 次准确率只有 75%,说明数据增强之后变得更难了。

而对于测试集,使用数据增强进行训练的时候,准确率会比不使用更高,因为数据增强提高了模型应对于更多的不同数据集的泛化能力,所以有更好的效果。

以上就是深度学习入门之Pytorch 数据增强的实现的详细内容,更多关于Pytorch 数据增强的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python操作摄像头截图实现远程监控的例子
Mar 25 Python
Python写的Tkinter程序屏幕居中方法
Mar 10 Python
python中django框架通过正则搜索页面上email地址的方法
Mar 21 Python
python入门基础之用户输入与模块初认识
Nov 14 Python
TensorFlow搭建神经网络最佳实践
Mar 09 Python
python实现图片筛选程序
Oct 24 Python
python实现Virginia无密钥解密
Mar 20 Python
python实现复制大量文件功能
Aug 31 Python
Python StringIO如何在内存中读写str
Jan 07 Python
Python对称的二叉树多种思路实现方法
Feb 28 Python
详解Python中@staticmethod和@classmethod区别及使用示例代码
Dec 14 Python
如何用python实现一个HTTP连接池
Jan 14 Python
Python基于Dlib的人脸识别系统的实现
Feb 26 #Python
python 回溯法模板详解
Feb 26 #Python
python实现信号时域统计特征提取代码
Feb 26 #Python
Python 基于FIR实现Hilbert滤波器求信号包络详解
Feb 26 #Python
python实现逆滤波与维纳滤波示例
Feb 26 #Python
Python全面分析系统的时域特性和频率域特性
Feb 26 #Python
解决pycharm每次打开项目都需要配置解释器和安装库问题
Feb 26 #Python
You might like
PHP3 safe_mode 失效漏洞
2006/10/09 PHP
用PHP查询域名状态whois的类
2006/11/25 PHP
cache_lite试用
2007/02/14 PHP
php实现的css文件背景图片下载器代码
2014/11/11 PHP
php中错误处理操作实例分析
2019/08/23 PHP
jQuery的学习步骤
2011/02/23 Javascript
用js写了一个类似php的print_r输出换行功能
2013/02/18 Javascript
在Iframe中获取父窗口中表单的值(示例代码)
2013/11/22 Javascript
jquery实现的伪分页效果代码
2015/10/29 Javascript
Vue.js实现移动端短信验证码功能
2017/03/29 Javascript
JS去掉字符串前后空格、阻止表单提交的实现代码
2017/06/08 Javascript
JavaScript指定断点操作实例教程
2018/09/18 Javascript
Vuex的基本概念、项目搭建以及入坑点
2018/11/04 Javascript
Vue项目使用localStorage+Vuex保存用户登录信息
2019/05/27 Javascript
vue实现鼠标移过出现下拉二级菜单功能
2019/12/12 Javascript
js实现拖动缓动效果
2020/01/13 Javascript
JQuery通过键盘控制键盘按下与松开触发事件
2020/08/07 jQuery
[01:20]DOTA2上海特级锦标赛现场采访:谁的ID最受青睐
2016/03/25 DOTA
安装Python的教程-Windows
2017/07/22 Python
详解pyqt5 动画在QThread线程中无法运行问题
2018/05/05 Python
python 编码规范整理
2018/05/05 Python
Python可变和不可变、类的私有属性实例分析
2019/05/31 Python
python集合常见运算案例解析
2019/10/17 Python
可自定义箭头样式的CSS3气泡提示框
2016/03/16 HTML / CSS
Lyle & Scott苏格兰金鹰官网:英国皇室御用品牌
2018/05/09 全球购物
Athleta官网:购买女士瑜伽服、技术运动服和休闲运动服
2020/11/12 全球购物
JRE、JDK、JVM之间的关系怎样
2012/05/16 面试题
护士自我鉴定范文
2013/10/06 职场文书
数控技术学生的自我评价
2014/02/15 职场文书
2014年党支部承诺书
2014/05/30 职场文书
校友回访母校寄语
2015/02/26 职场文书
2015年女工委工作总结
2015/07/27 职场文书
高一数学教学反思
2016/02/18 职场文书
2019职场实习报告该怎么写?
2019/07/01 职场文书
一劳永逸彻底解决pip install慢的办法
2021/05/24 Python
Spring事务管理下synchronized锁失效问题的解决方法
2022/03/31 Java/Android