pytorch 图像中的数据预处理和批标准化实例


Posted in Python onJanuary 15, 2020

目前数据预处理最常见的方法就是中心化和标准化。

中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。

标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间

批标准化:BN

在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。

所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。

batch normalization 的实现非常简单,接下来写一下对应的python代码:

import sys
sys.path.append('..')
 
import torch
 
def simple_batch_norm_1d(x, gamma, beta):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
   
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)

测试的时候该使用批标准化吗?

答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替

下面我们实现以下能够区分训练状态和测试状态的批标准化方法

def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  if is_training:
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
    moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
  else:
    x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

下面我们在卷积网络下试用一下批标准化看看效果

def data_tf(x):
  x = np.array(x, dtype='float32') / 255
  x = (x - 0.5) / 0.5 # 数据预处理,标准化
  x = torch.from_numpy(x)
  x = x.unsqueeze(0)
  return x
 
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
  def __init__(self):
    super(conv_bn_net, self).__init__()
    self.stage1 = nn.Sequential(
      nn.Conv2d(1, 6, 3, padding=1),
      nn.BatchNorm2d(6),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2),
      nn.Conv2d(6, 16, 5),
      nn.BatchNorm2d(16),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2)
    )
    
    self.classfy = nn.Linear(400, 10)
  def forward(self, x):
    x = self.stage1(x)
    x = x.view(x.shape[0], -1)
    x = self.classfy(x)
    return x
 
net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1
 
 
train(net, train_data, test_data, 5, optimizer, criterion)

以上这篇pytorch 图像中的数据预处理和批标准化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
部署Python的框架下的web app的详细教程
Apr 30 Python
利用python发送和接收邮件
Sep 27 Python
python的构建工具setup.py的方法使用示例
Oct 23 Python
Python pyinotify日志监控系统处理日志的方法
Mar 08 Python
Python实现的建造者模式示例
Aug 06 Python
python实现键盘控制鼠标移动
Nov 27 Python
python下PyGame的下载与安装过程及遇到问题
Aug 04 Python
简单易懂Pytorch实战实例VGG深度网络
Aug 27 Python
python基于三阶贝塞尔曲线的数据平滑算法
Dec 27 Python
Python 实现向word(docx)中输出
Feb 13 Python
python对Excel的读取的示例代码
Feb 14 Python
PyCharm取消波浪线、下划线和中划线的实现
Mar 03 Python
pytorch实现特殊的Module--Sqeuential三种写法
Jan 15 #Python
python实现删除列表中某个元素的3种方法
Jan 15 #Python
python opencv根据颜色进行目标检测的方法示例
Jan 15 #Python
Python基于Tensor FLow的图像处理操作详解
Jan 15 #Python
OpenCV哈里斯(Harris)角点检测的实现
Jan 15 #Python
Pytorch模型转onnx模型实例
Jan 15 #Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
Jan 15 #Python
You might like
mysql下创建字段并设置主键的php代码
2010/05/16 PHP
PHP递归算法的详细示例分析
2013/02/19 PHP
10款实用的PHP开源工具
2015/10/23 PHP
学习php设计模式 php实现建造者模式
2015/12/07 PHP
yii2中LinkPager增加总页数和总记录数的实例
2017/08/28 PHP
visual studio code 调试php方法(图文详解)
2017/09/15 PHP
PHP数组Key强制类型转换实现原理解析
2020/09/01 PHP
从零开始学习jQuery (四) jQuery中操作元素的属性与样式
2011/02/23 Javascript
js字符串转成JSON
2013/11/07 Javascript
js实现日历可获得指定日期周数及星期几示例分享(js获取星期几)
2014/03/14 Javascript
js中的getAttribute方法使用示例
2014/08/01 Javascript
详谈jQuery操纵DOM元素属性 attr()和removeAtrr()方法
2015/01/22 Javascript
超详细的javascript数组方法汇总
2015/11/21 Javascript
通过设置CSS中的position属性来固定层的位置
2015/12/14 Javascript
全面了解JavaScript对象进阶
2016/07/19 Javascript
jQuery中$.grep() 过滤函数 数组过滤
2016/11/22 Javascript
简单谈谈vue的过渡动画(推荐)
2017/10/11 Javascript
jQuery 导航自动跟随滚动的实现代码
2018/05/30 jQuery
了解JavaScript中let语句
2019/05/30 Javascript
微信小程序实现图片选择并预览功能
2019/07/25 Javascript
js判断非127开头的IP地址的实例代码
2020/01/05 Javascript
[34:27]DOTA2上海特级锦标赛B组败者赛 VG VS Spirit第一局
2016/02/26 DOTA
Python 中字符串拼接的多种方法
2018/07/30 Python
对python调用RPC接口的实例详解
2019/01/03 Python
python 浅谈serial与stm32通信的编码问题
2019/12/18 Python
利用Python脚本批量生成SQL语句
2020/03/04 Python
浅谈HTML5 服务器推送事件(Server-sent Events)
2017/08/01 HTML / CSS
Glamest意大利:女性在线奢侈品零售店
2019/04/28 全球购物
党的群众路线教育实践活动心得体会
2014/03/03 职场文书
2014年公司庆元旦活动方案
2014/03/05 职场文书
元旦联欢会主持词
2014/03/26 职场文书
《回乡偶书》教学反思
2014/04/12 职场文书
大学优秀班主任事迹材料
2014/05/02 职场文书
医学求职自荐信
2014/06/21 职场文书
习近平在党的群众路线教育实践活动总结大会上的讲话
2014/10/21 职场文书
SQL中的三种去重方法小结
2021/11/01 SQL Server