pytorch 图像中的数据预处理和批标准化实例


Posted in Python onJanuary 15, 2020

目前数据预处理最常见的方法就是中心化和标准化。

中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。

标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间

批标准化:BN

在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。

所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。

batch normalization 的实现非常简单,接下来写一下对应的python代码:

import sys
sys.path.append('..')
 
import torch
 
def simple_batch_norm_1d(x, gamma, beta):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
   
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)

测试的时候该使用批标准化吗?

答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替

下面我们实现以下能够区分训练状态和测试状态的批标准化方法

def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  if is_training:
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
    moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
  else:
    x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

下面我们在卷积网络下试用一下批标准化看看效果

def data_tf(x):
  x = np.array(x, dtype='float32') / 255
  x = (x - 0.5) / 0.5 # 数据预处理,标准化
  x = torch.from_numpy(x)
  x = x.unsqueeze(0)
  return x
 
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
  def __init__(self):
    super(conv_bn_net, self).__init__()
    self.stage1 = nn.Sequential(
      nn.Conv2d(1, 6, 3, padding=1),
      nn.BatchNorm2d(6),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2),
      nn.Conv2d(6, 16, 5),
      nn.BatchNorm2d(16),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2)
    )
    
    self.classfy = nn.Linear(400, 10)
  def forward(self, x):
    x = self.stage1(x)
    x = x.view(x.shape[0], -1)
    x = self.classfy(x)
    return x
 
net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1
 
 
train(net, train_data, test_data, 5, optimizer, criterion)

以上这篇pytorch 图像中的数据预处理和批标准化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的迭代器漫谈
Feb 03 Python
在Python编程过程中用单元测试法调试代码的介绍
Apr 02 Python
在Python中关于中文编码问题的处理建议
Apr 08 Python
python控制台中实现进度条功能
Nov 10 Python
bpython 功能强大的Python shell
Feb 16 Python
itchat和matplotlib的结合使用爬取微信信息的实例
Aug 25 Python
NLTK 3.2.4 环境搭建教程
Sep 19 Python
更改Python的pip install 默认安装依赖路径方法详解
Oct 27 Python
python覆盖写入,追加写入的实例
Jun 26 Python
Django单元测试工具test client使用详解
Aug 02 Python
python3中for循环踩过的坑记录
Dec 14 Python
使用numpy实现矩阵的翻转(flip)与旋转
Jun 03 Python
pytorch实现特殊的Module--Sqeuential三种写法
Jan 15 #Python
python实现删除列表中某个元素的3种方法
Jan 15 #Python
python opencv根据颜色进行目标检测的方法示例
Jan 15 #Python
Python基于Tensor FLow的图像处理操作详解
Jan 15 #Python
OpenCV哈里斯(Harris)角点检测的实现
Jan 15 #Python
Pytorch模型转onnx模型实例
Jan 15 #Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
Jan 15 #Python
You might like
php面向对象全攻略 (十一)__toString()用法 克隆对象 __call处理调用错误
2009/09/30 PHP
让PHP显示Facebook的粉丝数量方法
2014/01/08 PHP
PHP实现使用DOM将XML数据存入数组的方法示例
2017/09/27 PHP
详解php curl带有csrf-token验证模拟提交方法
2018/04/18 PHP
会自动逐行上升的文本框
2006/06/30 Javascript
JS 创建对象(常见的几种方法)
2008/11/03 Javascript
JavaScript Prototype对象
2009/01/07 Javascript
复制小说文本时出现的随机乱码的去除方法
2010/09/07 Javascript
关于Javascript 对象(object)的prototype
2014/05/09 Javascript
为什么JS中eval处理JSON数据要加括号
2015/04/13 Javascript
js实现YouKu的漂亮搜索框效果
2015/08/19 Javascript
跟我学习javascript解决异步编程异常方案
2015/11/23 Javascript
如何用JS判断两个数字的大小
2016/07/21 Javascript
JavaScript随机打乱数组顺序之随机洗牌算法
2016/08/02 Javascript
react 父组件与子组件之间的值传递的方法
2017/09/14 Javascript
vue.js如何将echarts封装为组件一键使用详解
2017/10/10 Javascript
vue设计一个倒计时秒杀的组件详解
2019/04/06 Javascript
elementUi vue el-radio 监听选中变化的实例代码
2019/06/28 Javascript
JS实现打砖块游戏
2020/02/14 Javascript
vue组件系列之TagsInput详解
2020/05/14 Javascript
[01:20]2018DOTA2亚洲邀请赛总决赛战队LGD晋级之路
2018/04/07 DOTA
Python获取单个程序CPU使用情况趋势图
2015/03/10 Python
在VS Code上搭建Python开发环境的方法
2018/04/06 Python
Python学习笔记之自定义函数用法详解
2019/06/08 Python
windows下安装Python虚拟环境virtualenvwrapper-win
2019/06/14 Python
基于Python解密仿射密码
2019/10/21 Python
有关Tensorflow梯度下降常用的优化方法分享
2020/02/04 Python
浅谈Pytorch中的自动求导函数backward()所需参数的含义
2020/02/29 Python
基于python爬取有道翻译过程图解
2020/03/31 Python
解决pycharm不能自动保存在远程linux中的问题
2021/02/06 Python
Noon埃及:埃及在线购物
2019/11/26 全球购物
根叔历年演讲稿
2014/05/20 职场文书
英语导游词
2015/02/13 职场文书
七年级之家长会发言稿范文
2019/09/04 职场文书
Python字典和列表性能之间的比较
2021/06/07 Python
详解Vue的列表渲染
2021/11/20 Vue.js