pytorch 图像中的数据预处理和批标准化实例


Posted in Python onJanuary 15, 2020

目前数据预处理最常见的方法就是中心化和标准化。

中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。

标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间

批标准化:BN

在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。

所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。

batch normalization 的实现非常简单,接下来写一下对应的python代码:

import sys
sys.path.append('..')
 
import torch
 
def simple_batch_norm_1d(x, gamma, beta):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
   
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)

测试的时候该使用批标准化吗?

答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替

下面我们实现以下能够区分训练状态和测试状态的批标准化方法

def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  if is_training:
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
    moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
  else:
    x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

下面我们在卷积网络下试用一下批标准化看看效果

def data_tf(x):
  x = np.array(x, dtype='float32') / 255
  x = (x - 0.5) / 0.5 # 数据预处理,标准化
  x = torch.from_numpy(x)
  x = x.unsqueeze(0)
  return x
 
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
  def __init__(self):
    super(conv_bn_net, self).__init__()
    self.stage1 = nn.Sequential(
      nn.Conv2d(1, 6, 3, padding=1),
      nn.BatchNorm2d(6),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2),
      nn.Conv2d(6, 16, 5),
      nn.BatchNorm2d(16),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2)
    )
    
    self.classfy = nn.Linear(400, 10)
  def forward(self, x):
    x = self.stage1(x)
    x = x.view(x.shape[0], -1)
    x = self.classfy(x)
    return x
 
net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1
 
 
train(net, train_data, test_data, 5, optimizer, criterion)

以上这篇pytorch 图像中的数据预处理和批标准化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用strip()方法删除字符串中空格的教程
May 20 Python
python简单获取本机计算机名和IP地址的方法
Jun 03 Python
在Django中限制已登录用户的访问的方法
Jul 23 Python
教你使用python画一朵花送女朋友
Mar 29 Python
详解python实现识别手写MNIST数字集的程序
Aug 03 Python
python 检查文件mime类型的方法
Dec 08 Python
Python 按字典dict的键排序,并取出相应的键值放于list中的实例
Feb 12 Python
Python和Java的语法对比分析语法简洁上python的确完美胜出
May 10 Python
pywinauto自动化操作记事本
Aug 26 Python
Python3+Django get/post请求实现教程详解
Feb 16 Python
解决TensorFlow训练模型及保存数量限制的问题
Mar 03 Python
python如何进行基准测试
Apr 26 Python
pytorch实现特殊的Module--Sqeuential三种写法
Jan 15 #Python
python实现删除列表中某个元素的3种方法
Jan 15 #Python
python opencv根据颜色进行目标检测的方法示例
Jan 15 #Python
Python基于Tensor FLow的图像处理操作详解
Jan 15 #Python
OpenCV哈里斯(Harris)角点检测的实现
Jan 15 #Python
Pytorch模型转onnx模型实例
Jan 15 #Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
Jan 15 #Python
You might like
php中file_get_contents与curl性能比较分析
2014/11/08 PHP
PHP下使用mysqli的函数连接mysql出现warning: mysqli::real_connect(): (hy000/1040): ...
2016/02/14 PHP
在html页面上拖放移动标签
2010/01/08 Javascript
JS命名空间的另一种实现
2013/08/09 Javascript
js获取内联样式的方法
2015/01/27 Javascript
javascript转换日期字符串为Date日期对象的方法
2015/02/13 Javascript
跟我学习javascript的arguments对象
2015/11/16 Javascript
JS实现设置ff与ie元素绝对位置的方法
2016/03/08 Javascript
深入理解JavaScript中的call、apply、bind方法的区别
2016/05/30 Javascript
总结Javascript中的隐式类型转换
2016/08/24 Javascript
jQuery文本框得到与失去焦点动态改变样式效果
2016/09/08 Javascript
KnockoutJS 3.X API 第四章之数据控制流if绑定和ifnot绑定
2016/10/10 Javascript
详解vue使用插槽分发内容slot的用法
2019/03/28 Javascript
通过实例了解Javascript柯里化流程
2020/03/03 Javascript
微信小程序实现自定义动画弹框/提示框的方法实例
2020/11/06 Javascript
Windows系统下安装Python的SSH模块教程
2015/02/05 Python
在Python中关于中文编码问题的处理建议
2015/04/08 Python
Python实现批量将word转html并将html内容发布至网站的方法
2015/07/14 Python
编写Python爬虫抓取豆瓣电影TOP100及用户头像的方法
2016/01/20 Python
python实现Decorator模式实例代码
2018/02/09 Python
python2.6.6如何升级到python2.7.14
2018/04/08 Python
Apache部署Django项目图文详解
2019/07/30 Python
解决Pycharm 导入其他文件夹源码的2种方法
2020/02/12 Python
python 对任意数据和曲线进行拟合并求出函数表达式的三种解决方案
2020/02/18 Python
python代码xml转txt实例
2020/03/10 Python
keras实现基于孪生网络的图片相似度计算方式
2020/06/11 Python
Django 用户认证Auth组件的使用
2020/11/30 Python
python通用数据库操作工具 pydbclib的使用简介
2020/12/21 Python
基于HTML5+tracking.js实现刷脸支付功能
2020/04/16 HTML / CSS
澳大利亚先进的皮肤和激光诊所购物网站:Soho Skincare
2018/10/15 全球购物
北京天润融通.net面试题笔试题
2012/02/20 面试题
Linux不知道文件后缀名怎么判断文件类型
2014/08/21 面试题
优秀幼教自荐信
2014/02/03 职场文书
生活委员竞选稿
2015/11/21 职场文书
Redis实现订单自动过期功能的示例代码
2021/05/08 Redis
Java Spring Lifecycle的使用
2022/05/06 Java/Android