pytorch 图像中的数据预处理和批标准化实例


Posted in Python onJanuary 15, 2020

目前数据预处理最常见的方法就是中心化和标准化。

中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。

标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间

批标准化:BN

在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。

所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。

batch normalization 的实现非常简单,接下来写一下对应的python代码:

import sys
sys.path.append('..')
 
import torch
 
def simple_batch_norm_1d(x, gamma, beta):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
   
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)

测试的时候该使用批标准化吗?

答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替

下面我们实现以下能够区分训练状态和测试状态的批标准化方法

def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  if is_training:
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
    moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
  else:
    x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

下面我们在卷积网络下试用一下批标准化看看效果

def data_tf(x):
  x = np.array(x, dtype='float32') / 255
  x = (x - 0.5) / 0.5 # 数据预处理,标准化
  x = torch.from_numpy(x)
  x = x.unsqueeze(0)
  return x
 
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
  def __init__(self):
    super(conv_bn_net, self).__init__()
    self.stage1 = nn.Sequential(
      nn.Conv2d(1, 6, 3, padding=1),
      nn.BatchNorm2d(6),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2),
      nn.Conv2d(6, 16, 5),
      nn.BatchNorm2d(16),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2)
    )
    
    self.classfy = nn.Linear(400, 10)
  def forward(self, x):
    x = self.stage1(x)
    x = x.view(x.shape[0], -1)
    x = self.classfy(x)
    return x
 
net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1
 
 
train(net, train_data, test_data, 5, optimizer, criterion)

以上这篇pytorch 图像中的数据预处理和批标准化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
linux平台使用Python制作BT种子并获取BT种子信息的方法
Jan 20 Python
python3 cvs将数据读取为字典的方法
Dec 22 Python
python 计算数据偏差和峰度的方法
Jun 29 Python
浅析Python语言自带的数据结构有哪些
Aug 27 Python
python 实现一个反向单位矩阵示例
Nov 29 Python
python 定义类时,实现内部方法的互相调用
Dec 25 Python
Python 实现数组相减示例
Dec 27 Python
解决pycharm中opencv-python导入cv2后无法自动补全的问题(不用作任何文件上的修改)
Mar 05 Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 Python
matlab、python中矩阵的互相导入导出方式
Jun 01 Python
彻底解决Python包下载慢问题
Nov 15 Python
python的列表生成式,生成器和generator对象你了解吗
Mar 16 Python
pytorch实现特殊的Module--Sqeuential三种写法
Jan 15 #Python
python实现删除列表中某个元素的3种方法
Jan 15 #Python
python opencv根据颜色进行目标检测的方法示例
Jan 15 #Python
Python基于Tensor FLow的图像处理操作详解
Jan 15 #Python
OpenCV哈里斯(Harris)角点检测的实现
Jan 15 #Python
Pytorch模型转onnx模型实例
Jan 15 #Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
Jan 15 #Python
You might like
php 三元运算符实例详细介绍
2016/12/15 PHP
PHP Swoole异步读取、写入文件操作示例
2019/10/24 PHP
laravel框架使用FormRequest进行表单验证,验证异常返回JSON操作示例
2020/02/18 PHP
防止网站内容被拷贝的一些方法与优缺点好处与坏处分析
2007/11/30 Javascript
JavaScript 面向对象编程(2) 定义类
2010/05/18 Javascript
基于Jquery的标签智能验证实现代码
2010/12/27 Javascript
Jquery实现自定义弹窗示例
2014/03/12 Javascript
jquery序列化表单去除指定元素示例代码
2014/04/10 Javascript
使用js画图之饼图
2015/01/12 Javascript
Bootstrap学习笔记之css样式设计(1)
2016/06/07 Javascript
表单中单选框添加选项和移除选项
2016/07/04 Javascript
JavaScript随机打乱数组顺序之随机洗牌算法
2016/08/02 Javascript
JavaScript队列、优先队列与循环队列
2016/11/14 Javascript
jQuery实现动态删除LI的方法
2017/05/30 jQuery
微信小程序仿朋友圈发布动态功能
2018/07/15 Javascript
vue文件运行的方法教学
2019/02/12 Javascript
layui 监听select选择 获取当前select的ID名称方法
2019/09/24 Javascript
详解Vite的新体验
2021/02/22 Javascript
[32:26]EG vs IG 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
python获得linux下所有挂载点(mount points)的方法
2015/04/29 Python
python检测是文件还是目录的方法
2015/07/03 Python
Python批量查询域名是否被注册过
2017/06/21 Python
Python多线程原理与用法详解
2018/08/20 Python
CentOS 7下安装Python3.6 及遇到的问题小结
2018/11/08 Python
python 删除字符串中连续多个空格并保留一个的方法
2018/12/22 Python
Python递归函数实例讲解
2019/02/27 Python
初学者学习Python好还是Java好
2020/05/26 Python
您的网上新华书店:文轩网
2016/08/24 全球购物
电子商务专业学生的学习自我评价
2013/10/27 职场文书
计算机专业毕业生自荐信
2013/12/31 职场文书
情人节寄语大全
2014/04/11 职场文书
班长竞选演讲稿
2014/04/24 职场文书
监察建议书格式
2014/05/19 职场文书
工作表扬信范文
2015/01/17 职场文书
关于运动会的宣传稿
2015/07/23 职场文书
python使用torch随机初始化参数
2022/03/22 Python