pytorch 图像中的数据预处理和批标准化实例


Posted in Python onJanuary 15, 2020

目前数据预处理最常见的方法就是中心化和标准化。

中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。

标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间

批标准化:BN

在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。

所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。

batch normalization 的实现非常简单,接下来写一下对应的python代码:

import sys
sys.path.append('..')
 
import torch
 
def simple_batch_norm_1d(x, gamma, beta):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
   
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)

测试的时候该使用批标准化吗?

答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替

下面我们实现以下能够区分训练状态和测试状态的批标准化方法

def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  if is_training:
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
    moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
  else:
    x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

下面我们在卷积网络下试用一下批标准化看看效果

def data_tf(x):
  x = np.array(x, dtype='float32') / 255
  x = (x - 0.5) / 0.5 # 数据预处理,标准化
  x = torch.from_numpy(x)
  x = x.unsqueeze(0)
  return x
 
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
  def __init__(self):
    super(conv_bn_net, self).__init__()
    self.stage1 = nn.Sequential(
      nn.Conv2d(1, 6, 3, padding=1),
      nn.BatchNorm2d(6),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2),
      nn.Conv2d(6, 16, 5),
      nn.BatchNorm2d(16),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2)
    )
    
    self.classfy = nn.Linear(400, 10)
  def forward(self, x):
    x = self.stage1(x)
    x = x.view(x.shape[0], -1)
    x = self.classfy(x)
    return x
 
net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1
 
 
train(net, train_data, test_data, 5, optimizer, criterion)

以上这篇pytorch 图像中的数据预处理和批标准化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python入门篇之条件、循环
Oct 17 Python
python多线程操作实例
Nov 21 Python
Python实现字典的key和values的交换
Aug 04 Python
Python实现网络端口转发和重定向的方法
Sep 19 Python
python xml.etree.ElementTree遍历xml所有节点实例详解
Dec 04 Python
Python实现的直接插入排序算法示例
Apr 29 Python
Python面向对象类编写细节分析【类,方法,继承,超类,接口等】
Jan 05 Python
详解Django+uwsgi+Nginx上线最佳实战
Mar 14 Python
Python实现转换图片背景颜色代码
Apr 30 Python
Python Json数据文件操作原理解析
May 09 Python
Python如何爬取51cto数据并存入MySQL
Aug 25 Python
浅谈anaconda python 版本对应关系
Oct 07 Python
pytorch实现特殊的Module--Sqeuential三种写法
Jan 15 #Python
python实现删除列表中某个元素的3种方法
Jan 15 #Python
python opencv根据颜色进行目标检测的方法示例
Jan 15 #Python
Python基于Tensor FLow的图像处理操作详解
Jan 15 #Python
OpenCV哈里斯(Harris)角点检测的实现
Jan 15 #Python
Pytorch模型转onnx模型实例
Jan 15 #Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
Jan 15 #Python
You might like
PHP 引用文件技巧
2010/03/02 PHP
通过5个php实例细致说明传值与传引用的区别
2012/08/08 PHP
PHP批量去除BOM头内容信息代码
2016/03/11 PHP
WEB 浏览器兼容 推荐收藏
2010/05/14 Javascript
js操作select控件的几种方法
2010/06/02 Javascript
利用google提供的API(JavaScript接口)获取网站访问者IP地理位置的代码详解
2010/07/24 Javascript
jquery索引在使用中的一些困惑
2013/10/24 Javascript
jquery给图片添加鼠标经过时的边框效果
2013/11/12 Javascript
用unescape反编码得出汉字示例
2014/04/24 Javascript
JQuery异步加载PartialView的方法
2016/06/07 Javascript
jQuery自适应轮播图插件Swiper用法示例
2016/08/24 Javascript
Angular HMR(热模块替换)功能实现方法
2018/04/04 Javascript
详解vue使用vue-layer-mobile组件实现toast,loading效果
2018/08/31 Javascript
详解在vue-cli中使用graphql即vue-apollo的用法
2018/09/08 Javascript
在layui下对元素进行事件绑定的实例
2019/09/06 Javascript
JavaScript代理模式原理与用法实例详解
2020/03/10 Javascript
[01:38]完美世界DOTA2联赛(PWL)宣传片:第一站
2020/10/26 DOTA
[46:49]完美世界DOTA2联赛PWL S3 access vs Rebirth 第二场 12.19
2020/12/24 DOTA
Python编程之微信推送模板消息功能示例
2017/08/21 Python
同时安装Python2 & Python3 cmd下版本自由选择的方法
2017/12/09 Python
python 列表递归求和、计数、求最大元素的实例
2018/11/28 Python
CSS3改变浏览器滚动条样式
2019/01/04 HTML / CSS
雅诗兰黛旗下专业男士保养领导品牌:Lab Series
2017/05/15 全球购物
贝斯特韦斯特酒店集团官网:Best Western
2019/01/03 全球购物
最好的商品表达自己:Cafepress
2019/09/04 全球购物
西班牙鞋子和箱包在线销售网站:zapatos.es
2020/02/17 全球购物
泰国时尚电商:POMELO Fashion
2020/03/11 全球购物
销售工作人员的自我评价分享
2013/11/10 职场文书
大学自我评价
2014/02/12 职场文书
毕业生写求职信的要点
2014/03/04 职场文书
教师对照四风自我剖析材料
2014/09/30 职场文书
防灾减灾标语
2014/10/07 职场文书
离婚协议书范文2014
2014/10/16 职场文书
2019幼儿园感恩节活动策划书
2019/11/28 职场文书
老生常谈 使用 CSS 实现三角形的技巧(多种方法)
2021/04/13 HTML / CSS
【D4DJ】美少女DJ企划 动画将于明年冬季开播第2季
2022/04/11 日漫