pytorch 图像中的数据预处理和批标准化实例


Posted in Python onJanuary 15, 2020

目前数据预处理最常见的方法就是中心化和标准化。

中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。

标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间

批标准化:BN

在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。

所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。

batch normalization 的实现非常简单,接下来写一下对应的python代码:

import sys
sys.path.append('..')
 
import torch
 
def simple_batch_norm_1d(x, gamma, beta):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
   
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)

测试的时候该使用批标准化吗?

答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替

下面我们实现以下能够区分训练状态和测试状态的批标准化方法

def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  if is_training:
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
    moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
  else:
    x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

下面我们在卷积网络下试用一下批标准化看看效果

def data_tf(x):
  x = np.array(x, dtype='float32') / 255
  x = (x - 0.5) / 0.5 # 数据预处理,标准化
  x = torch.from_numpy(x)
  x = x.unsqueeze(0)
  return x
 
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
  def __init__(self):
    super(conv_bn_net, self).__init__()
    self.stage1 = nn.Sequential(
      nn.Conv2d(1, 6, 3, padding=1),
      nn.BatchNorm2d(6),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2),
      nn.Conv2d(6, 16, 5),
      nn.BatchNorm2d(16),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2)
    )
    
    self.classfy = nn.Linear(400, 10)
  def forward(self, x):
    x = self.stage1(x)
    x = x.view(x.shape[0], -1)
    x = self.classfy(x)
    return x
 
net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1
 
 
train(net, train_data, test_data, 5, optimizer, criterion)

以上这篇pytorch 图像中的数据预处理和批标准化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础教程之获取本机ip数据包示例
Feb 10 Python
python基础教程之数字处理(math)模块详解
Mar 25 Python
python人人网登录应用实例
Sep 26 Python
Python标准异常和异常处理详解
Feb 02 Python
Django+Xadmin构建项目的方法步骤
Mar 06 Python
python3利用Socket实现通信的方法示例
May 06 Python
Python循环实现n的全排列功能
Sep 16 Python
python 字段拆分详解
Dec 17 Python
Python中的 ansible 动态Inventory 脚本
Jan 19 Python
浅谈Python3多线程之间的执行顺序问题
May 02 Python
如何理解Python中的变量
Jun 01 Python
Python JSON常用编解码方法代码实例
Sep 05 Python
pytorch实现特殊的Module--Sqeuential三种写法
Jan 15 #Python
python实现删除列表中某个元素的3种方法
Jan 15 #Python
python opencv根据颜色进行目标检测的方法示例
Jan 15 #Python
Python基于Tensor FLow的图像处理操作详解
Jan 15 #Python
OpenCV哈里斯(Harris)角点检测的实现
Jan 15 #Python
Pytorch模型转onnx模型实例
Jan 15 #Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
Jan 15 #Python
You might like
php上传、管理照片示例
2006/10/09 PHP
php include的妙用,实现路径加密
2008/07/29 PHP
PHP获取文件扩展名的4种方法
2015/11/24 PHP
PHP基于mcript扩展实现对称加密功能示例
2019/02/21 PHP
TNC vs BOOM BO3 第二场2.13
2021/03/10 DOTA
javascript英文日期(有时间)选择器
2007/05/02 Javascript
初窥JQuery(二) 事件机制(1)
2010/11/25 Javascript
JavaScript Title、alt提示(Tips)实现源码解读
2010/12/12 Javascript
选择器中含有空格在使用示例及注意事项
2013/07/31 Javascript
jquery实现div拖拽宽度示例代码
2013/07/31 Javascript
JavaScript弹出窗口方法汇总
2014/08/12 Javascript
Node.js 去掉种子(torrent)文件里的邪恶信息
2015/03/27 Javascript
Javascript仿新浪游戏频道鼠标悬停显示子菜单效果
2015/08/21 Javascript
BootStrap的JS插件之轮播效果案例详解
2016/05/16 Javascript
JavaScript中ES6 Babel正确安装过程
2016/07/18 Javascript
详解vue-cli 接口代理配置
2017/12/13 Javascript
jquery.onoff实现简单的开关按钮功能(推荐)
2018/05/24 jQuery
从vue源码看props的用法
2019/01/09 Javascript
Vue使用axios出现options请求方法
2019/05/30 Javascript
vue vant Area组件使用详解
2019/12/09 Javascript
Python实现的HTTP并发测试完整示例
2020/04/23 Python
利用python在大量数据文件下删除某一行的例子
2019/08/21 Python
Python包,__init__.py功能与用法分析
2020/01/07 Python
Python爬取12306车次信息代码详解
2020/08/12 Python
Nike荷兰官方网站:Nike.com (NL)
2018/04/19 全球购物
大学生全国两会报告感想
2014/03/17 职场文书
初中家长寄语
2014/04/02 职场文书
《狼和小羊》教学反思
2014/04/20 职场文书
省文明单位申报材料
2014/05/08 职场文书
老干部工作先进集体事迹材料
2014/05/21 职场文书
办公用房租赁协议书
2014/11/29 职场文书
2014年保卫科工作总结
2014/12/05 职场文书
颐和园英文导游词
2015/01/30 职场文书
2015年学校工作总结范文
2015/04/20 职场文书
手把手教你用SpringBoot将文件打包成zip存放或导出
2021/06/11 Java/Android
简单聊聊Golang中defer预计算参数
2022/03/25 Golang