pytorch 图像中的数据预处理和批标准化实例


Posted in Python onJanuary 15, 2020

目前数据预处理最常见的方法就是中心化和标准化。

中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。

标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间

批标准化:BN

在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。

所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。

batch normalization 的实现非常简单,接下来写一下对应的python代码:

import sys
sys.path.append('..')
 
import torch
 
def simple_batch_norm_1d(x, gamma, beta):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
   
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)

测试的时候该使用批标准化吗?

答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替

下面我们实现以下能够区分训练状态和测试状态的批标准化方法

def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  if is_training:
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
    moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
  else:
    x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

下面我们在卷积网络下试用一下批标准化看看效果

def data_tf(x):
  x = np.array(x, dtype='float32') / 255
  x = (x - 0.5) / 0.5 # 数据预处理,标准化
  x = torch.from_numpy(x)
  x = x.unsqueeze(0)
  return x
 
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
  def __init__(self):
    super(conv_bn_net, self).__init__()
    self.stage1 = nn.Sequential(
      nn.Conv2d(1, 6, 3, padding=1),
      nn.BatchNorm2d(6),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2),
      nn.Conv2d(6, 16, 5),
      nn.BatchNorm2d(16),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2)
    )
    
    self.classfy = nn.Linear(400, 10)
  def forward(self, x):
    x = self.stage1(x)
    x = x.view(x.shape[0], -1)
    x = self.classfy(x)
    return x
 
net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1
 
 
train(net, train_data, test_data, 5, optimizer, criterion)

以上这篇pytorch 图像中的数据预处理和批标准化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用urllib模块和pyquery实现阿里巴巴排名查询
Jan 16 Python
Python中super关键字用法实例分析
May 28 Python
Python爬虫设置代理IP的方法(爬虫技巧)
Mar 04 Python
django的登录注册系统的示例代码
May 14 Python
python爬取个性签名的方法
Jun 17 Python
python运行时强制刷新缓冲区的方法
Jan 14 Python
python的pygal模块绘制反正切函数图像方法
Jul 16 Python
python文字和unicode/ascll相互转换函数及简单加密解密实现代码
Aug 12 Python
Python3 获取文件属性的方式(时间、大小等)
Mar 12 Python
探秘TensorFlow 和 NumPy 的 Broadcasting 机制
Mar 13 Python
Python count函数使用方法实例解析
Mar 23 Python
20行Python代码实现视频字符化功能
Apr 13 Python
pytorch实现特殊的Module--Sqeuential三种写法
Jan 15 #Python
python实现删除列表中某个元素的3种方法
Jan 15 #Python
python opencv根据颜色进行目标检测的方法示例
Jan 15 #Python
Python基于Tensor FLow的图像处理操作详解
Jan 15 #Python
OpenCV哈里斯(Harris)角点检测的实现
Jan 15 #Python
Pytorch模型转onnx模型实例
Jan 15 #Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
Jan 15 #Python
You might like
DC动画很好看?新作烂得令人发指,名叫《红色之子》
2020/04/09 欧美动漫
PHP syntax error, unexpected $end 错误的一种原因及解决
2008/10/25 PHP
php自定义session示例分享
2014/04/22 PHP
yii框架通过控制台命令创建定时任务示例
2014/04/30 PHP
laravel实现一个上传图片的接口,并建立软链接,访问图片的方法
2019/10/12 PHP
javascript 面向对象全新理练之数据的封装
2009/12/03 Javascript
jQuery 三击事件实现代码
2013/09/11 Javascript
JQuery下拉框应用示例介绍
2014/04/23 Javascript
js实现网页收藏功能
2015/12/17 Javascript
js实现页面跳转的五种方法推荐
2016/03/10 Javascript
Bootstrap警告框(Alert)插件使用方法
2017/03/21 Javascript
JS中正则表达式全局匹配模式 /g用法详解
2017/04/01 Javascript
jQuery插件FusionCharts绘制2D环饼图效果示例【附demo源码】
2017/04/10 jQuery
js+html5实现复制文字按钮
2017/07/15 Javascript
基于JavaScript实现选项卡效果
2017/07/21 Javascript
node读写Excel操作实例分析
2019/11/06 Javascript
vue-property-decorator用法详解
2019/12/12 Javascript
解决Vue中使用keepAlive不缓存问题
2020/08/04 Javascript
Flask框架中密码的加盐哈希加密和验证功能的用法详解
2016/06/07 Python
Python数据可视化之画图
2019/01/15 Python
Python当中的array数组对象实例详解
2019/06/12 Python
python+opencv实现摄像头调用的方法
2019/06/22 Python
python config文件的读写操作示例
2019/09/27 Python
ipad上运行python的方法步骤
2019/10/12 Python
Python对Excel按列值筛选并拆分表格到多个文件的代码
2019/11/05 Python
python:删除离群值操作(每一行为一类数据)
2020/06/08 Python
python如何操作mysql
2020/08/17 Python
Windows环境下Python3.6.8 importError: DLLload failed:找不到指定的模块
2020/11/01 Python
使用javascript和HTML5 Canvas画的四渐变色播放按钮效果
2014/04/10 HTML / CSS
机械绘图员岗位职责
2013/11/19 职场文书
授权委托书范本
2014/04/03 职场文书
承诺书格式范文
2014/06/03 职场文书
mysql优化之query_cache_limit参数说明
2021/07/01 MySQL
在redisCluster中模糊获取key方式
2021/07/09 Redis
sql时间段切分实现每隔x分钟出一份高速门架车流量
2022/02/28 SQL Server
win10壁纸在哪个文件夹 win10桌面背景图片文件位置分享
2022/08/05 数码科技