pytorch 图像中的数据预处理和批标准化实例


Posted in Python onJanuary 15, 2020

目前数据预处理最常见的方法就是中心化和标准化。

中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。

标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间

批标准化:BN

在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。

所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。

batch normalization 的实现非常简单,接下来写一下对应的python代码:

import sys
sys.path.append('..')
 
import torch
 
def simple_batch_norm_1d(x, gamma, beta):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
   
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)

测试的时候该使用批标准化吗?

答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替

下面我们实现以下能够区分训练状态和测试状态的批标准化方法

def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  if is_training:
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
    moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
  else:
    x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

下面我们在卷积网络下试用一下批标准化看看效果

def data_tf(x):
  x = np.array(x, dtype='float32') / 255
  x = (x - 0.5) / 0.5 # 数据预处理,标准化
  x = torch.from_numpy(x)
  x = x.unsqueeze(0)
  return x
 
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
  def __init__(self):
    super(conv_bn_net, self).__init__()
    self.stage1 = nn.Sequential(
      nn.Conv2d(1, 6, 3, padding=1),
      nn.BatchNorm2d(6),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2),
      nn.Conv2d(6, 16, 5),
      nn.BatchNorm2d(16),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2)
    )
    
    self.classfy = nn.Linear(400, 10)
  def forward(self, x):
    x = self.stage1(x)
    x = x.view(x.shape[0], -1)
    x = self.classfy(x)
    return x
 
net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1
 
 
train(net, train_data, test_data, 5, optimizer, criterion)

以上这篇pytorch 图像中的数据预处理和批标准化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用os模块的os.walk遍历文件夹示例
Jan 27 Python
Python写的一个简单监控系统
Jun 19 Python
python用Pygal如何生成漂亮的SVG图像详解
Feb 10 Python
浅谈Python2获取中文文件名的编码问题
Jan 09 Python
python调用c++ ctype list传数组或者返回数组的方法
Feb 13 Python
Python数据结构与算法(几种排序)小结
Jun 22 Python
Django1.11自带分页器paginator的使用方法
Oct 31 Python
python做接口测试的必要性
Nov 20 Python
pytorch 实现模型不同层设置不同的学习率方式
Jan 06 Python
python实现ip地址的包含关系判断
Feb 07 Python
对Python中 \r, \n, \r\n的彻底理解
Mar 06 Python
使用Python实现微信拍一拍功能的思路代码
Jul 09 Python
pytorch实现特殊的Module--Sqeuential三种写法
Jan 15 #Python
python实现删除列表中某个元素的3种方法
Jan 15 #Python
python opencv根据颜色进行目标检测的方法示例
Jan 15 #Python
Python基于Tensor FLow的图像处理操作详解
Jan 15 #Python
OpenCV哈里斯(Harris)角点检测的实现
Jan 15 #Python
Pytorch模型转onnx模型实例
Jan 15 #Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
Jan 15 #Python
You might like
PHP与MySQL交互使用详解
2006/10/09 PHP
php include类文件超时问题处理
2015/02/06 PHP
php调用KyotoTycoon简单实例
2015/04/02 PHP
PHP判断FORM表单或URL参数来的数据是否为整数的方法
2016/03/25 PHP
php版微信公众平台入门教程之开发者认证的方法
2016/09/26 PHP
JavaScript对象链式操作代码(jquery)
2010/07/04 Javascript
JS关闭窗口与JS关闭页面的几种方法小结
2013/12/17 Javascript
浅谈轻量级js模板引擎simplite
2015/02/13 Javascript
jquery 表单验证之通过 class验证表单不为空
2015/11/02 Javascript
基于JS实现textarea中获取动态剩余字数的方法
2016/05/25 Javascript
AngularJS表单详解及示例代码
2016/08/17 Javascript
jQuery实现弹窗居中效果类似alert()
2017/02/27 Javascript
bootstrap实现动态进度条效果
2017/03/08 Javascript
移动设备手势事件库Touch.js使用详解
2017/08/18 Javascript
详解JavaScript中的六种错误类型
2017/09/21 Javascript
利用原生js实现html5小游戏之打砖块(附源码)
2018/01/03 Javascript
微信小程序自定义组件实现tabs选项卡功能
2018/07/14 Javascript
vue项目打包后上传至GitHub并实现github-pages的预览
2019/05/06 Javascript
vue项目中使用scss的方法步骤
2019/05/16 Javascript
使用Vue实现一个树组件的示例
2020/11/06 Javascript
vue中配置scss全局变量的步骤
2020/12/28 Vue.js
Python变量赋值的秘密分享
2018/04/03 Python
python命令 -u参数用法解析
2019/10/24 Python
html5摇一摇代码优化包括DeviceMotionEvent等等
2014/09/01 HTML / CSS
处理textarea中的换行和空格
2019/12/12 HTML / CSS
英国版MAC彩妆品牌:Illamasqua
2018/04/18 全球购物
什么是"引用"?申明和使用"引用"要注意哪些问题?
2016/03/03 面试题
售前工程师职业生涯规划
2014/03/02 职场文书
新春联欢会主持词
2014/03/24 职场文书
市场部经理岗位职责
2014/04/10 职场文书
党支部先进事迹材料
2014/12/24 职场文书
向雷锋同志学习倡议书
2015/04/27 职场文书
名人传读书笔记
2015/06/26 职场文书
python入门之算法学习
2021/04/22 Python
python中requests库+xpath+lxml简单使用
2021/04/29 Python
使用canvas对video视频某一刻截图功能
2021/09/25 HTML / CSS