pytorch 图像中的数据预处理和批标准化实例


Posted in Python onJanuary 15, 2020

目前数据预处理最常见的方法就是中心化和标准化。

中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。

标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间

批标准化:BN

在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。

所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。

batch normalization 的实现非常简单,接下来写一下对应的python代码:

import sys
sys.path.append('..')
 
import torch
 
def simple_batch_norm_1d(x, gamma, beta):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
   
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)

测试的时候该使用批标准化吗?

答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替

下面我们实现以下能够区分训练状态和测试状态的批标准化方法

def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  if is_training:
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
    moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
  else:
    x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

下面我们在卷积网络下试用一下批标准化看看效果

def data_tf(x):
  x = np.array(x, dtype='float32') / 255
  x = (x - 0.5) / 0.5 # 数据预处理,标准化
  x = torch.from_numpy(x)
  x = x.unsqueeze(0)
  return x
 
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
  def __init__(self):
    super(conv_bn_net, self).__init__()
    self.stage1 = nn.Sequential(
      nn.Conv2d(1, 6, 3, padding=1),
      nn.BatchNorm2d(6),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2),
      nn.Conv2d(6, 16, 5),
      nn.BatchNorm2d(16),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2)
    )
    
    self.classfy = nn.Linear(400, 10)
  def forward(self, x):
    x = self.stage1(x)
    x = x.view(x.shape[0], -1)
    x = self.classfy(x)
    return x
 
net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1
 
 
train(net, train_data, test_data, 5, optimizer, criterion)

以上这篇pytorch 图像中的数据预处理和批标准化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python进阶教程之函数参数的多种传递方法
Aug 30 Python
Python3实现转换Image图片格式
Jun 21 Python
python/sympy求解矩阵方程的方法
Nov 08 Python
浅谈python的深浅拷贝以及fromkeys的用法
Mar 08 Python
Python实现的爬取小说爬虫功能示例
Mar 30 Python
Python中print和return的作用及区别解析
May 05 Python
python 自定义装饰器实例详解
Jul 20 Python
Django配置文件代码说明
Dec 04 Python
PyTorch的SoftMax交叉熵损失和梯度用法
Jan 15 Python
python实现快递价格查询系统
Mar 03 Python
解决Python安装cryptography报错问题
Sep 03 Python
python搜索算法原理及实例讲解
Nov 18 Python
pytorch实现特殊的Module--Sqeuential三种写法
Jan 15 #Python
python实现删除列表中某个元素的3种方法
Jan 15 #Python
python opencv根据颜色进行目标检测的方法示例
Jan 15 #Python
Python基于Tensor FLow的图像处理操作详解
Jan 15 #Python
OpenCV哈里斯(Harris)角点检测的实现
Jan 15 #Python
Pytorch模型转onnx模型实例
Jan 15 #Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
Jan 15 #Python
You might like
php的dl函数用法实例
2014/11/06 PHP
js判断变量是否空值的代码
2008/10/26 Javascript
jquery 输入框数字限制插件
2009/11/10 Javascript
JavaScript高级程序设计 读书笔记之十一 内置对象Global
2012/03/07 Javascript
js弹出层包含flash 不能关闭隐藏的2种处理方法
2013/06/17 Javascript
javascript获取设置div的高度和宽度兼容任何浏览器
2013/09/22 Javascript
JavaScript调用客户端的可执行文件(示例代码)
2013/11/28 Javascript
JavaScript中的单引号和双引号报错的解决方法
2014/09/01 Javascript
JavaScript中判断两个字符串是否相等的方法
2015/07/07 Javascript
巧用jQuery选择器提高写表单效率的方法
2016/08/19 Javascript
原生js轮播特效
2017/05/18 Javascript
详解JavaScript数组过滤相同元素的5种方法
2017/05/23 Javascript
vue绑定class与行间样式style详解
2017/08/16 Javascript
jQuery md5加密插件jQuery.md5.js用法示例
2018/08/24 jQuery
vue项目前端知识点整理【收藏】
2019/05/13 Javascript
在Vue中使用this.$store或者是$route一直报错的解决
2019/11/08 Javascript
在vue中使用防抖和节流,防止重复点击或重复上拉加载实例
2019/11/13 Javascript
详解element上传组件before-remove钩子问题解决
2020/04/08 Javascript
JavaScript动态生成表格的示例
2020/11/02 Javascript
[25:45]2018DOTA2亚洲邀请赛4.5SOLO赛 Sylar vs Paparazi
2018/04/06 DOTA
python脚本设置系统时间的两种方法
2016/02/21 Python
python3使用PyMysql连接mysql数据库实例
2017/02/07 Python
python selenium爬取斗鱼所有直播房间信息过程详解
2019/08/09 Python
接口自动化多层嵌套json数据处理代码实例
2020/11/20 Python
关于多种方式完美解决Python pip命令下载第三方库的问题
2020/12/21 Python
css3动画鼠标放上图片逐渐变大鼠标离开图片逐渐缩小效果
2021/01/27 HTML / CSS
医学院四年学习生活的自我评价
2013/11/06 职场文书
班级学习雷锋活动总结
2014/07/04 职场文书
ktv好的活动方案
2014/08/15 职场文书
我心目中的好老师活动方案
2014/08/19 职场文书
2014领导干部四风问题查摆思想汇报
2014/09/13 职场文书
乡镇党的群众路线教育实践活动剖析材料
2014/10/09 职场文书
2014年法院工作总结
2014/11/24 职场文书
七年级作文之环保作文
2019/10/17 职场文书
python中字符串String及其常见操作指南(方法、函数)
2022/04/06 Python
CSS 鼠标选中文字后改变背景色的实现代码
2023/05/21 HTML / CSS