pytorch 图像中的数据预处理和批标准化实例


Posted in Python onJanuary 15, 2020

目前数据预处理最常见的方法就是中心化和标准化。

中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。

标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间

批标准化:BN

在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。

所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。

batch normalization 的实现非常简单,接下来写一下对应的python代码:

import sys
sys.path.append('..')
 
import torch
 
def simple_batch_norm_1d(x, gamma, beta):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
   
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)

测试的时候该使用批标准化吗?

答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替

下面我们实现以下能够区分训练状态和测试状态的批标准化方法

def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  if is_training:
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
    moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
  else:
    x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

下面我们在卷积网络下试用一下批标准化看看效果

def data_tf(x):
  x = np.array(x, dtype='float32') / 255
  x = (x - 0.5) / 0.5 # 数据预处理,标准化
  x = torch.from_numpy(x)
  x = x.unsqueeze(0)
  return x
 
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
  def __init__(self):
    super(conv_bn_net, self).__init__()
    self.stage1 = nn.Sequential(
      nn.Conv2d(1, 6, 3, padding=1),
      nn.BatchNorm2d(6),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2),
      nn.Conv2d(6, 16, 5),
      nn.BatchNorm2d(16),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2)
    )
    
    self.classfy = nn.Linear(400, 10)
  def forward(self, x):
    x = self.stage1(x)
    x = x.view(x.shape[0], -1)
    x = self.classfy(x)
    return x
 
net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1
 
 
train(net, train_data, test_data, 5, optimizer, criterion)

以上这篇pytorch 图像中的数据预处理和批标准化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python的框架中为MySQL实现restful接口的教程
Apr 08 Python
python使用os.listdir和os.walk获得文件的路径的方法
Dec 16 Python
一个简单的python爬虫程序 爬取豆瓣热度Top100以内的电影信息
Apr 17 Python
浅谈pycharm的xmx和xms设置方法
Dec 03 Python
Python操作rabbitMQ的示例代码
Mar 19 Python
Python帮你微信头像任意添加装饰别再@微信官方了
Sep 25 Python
python中利用matplotlib读取灰度图的例子
Dec 07 Python
Python阶乘求和的代码详解
Feb 14 Python
Python如何向SQLServer存储二进制图片
Jun 08 Python
Keras 中Leaky ReLU等高级激活函数的用法
Jul 05 Python
python实现网页录音效果
Oct 26 Python
Python字典实现伪切片功能
Oct 28 Python
pytorch实现特殊的Module--Sqeuential三种写法
Jan 15 #Python
python实现删除列表中某个元素的3种方法
Jan 15 #Python
python opencv根据颜色进行目标检测的方法示例
Jan 15 #Python
Python基于Tensor FLow的图像处理操作详解
Jan 15 #Python
OpenCV哈里斯(Harris)角点检测的实现
Jan 15 #Python
Pytorch模型转onnx模型实例
Jan 15 #Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
Jan 15 #Python
You might like
博士208HAF收音机实习报告
2021/03/02 无线电
细谈php中SQL注入攻击与XSS攻击
2012/06/10 PHP
php 伪静态之IIS篇
2014/06/02 PHP
分享PHP计算两个日期相差天数的代码
2015/12/23 PHP
登陆成功后自动计算秒数执行跳转
2014/01/23 Javascript
javascript中setTimeout和setInterval的unref()和ref()用法示例
2014/11/26 Javascript
jsMind通过鼠标拖拽的方式调整节点位置
2015/04/13 Javascript
ionic js 复选框 与普通的 HTML 复选框到底有没区别
2016/06/06 Javascript
Laravel中常见的错误与解决方法小结
2016/08/30 Javascript
JS之获取样式的简单实现方法(推荐)
2016/09/13 Javascript
JS实现的简单拖拽购物车功能示例【附源码下载】
2018/01/03 Javascript
vue使用iframe嵌入网页的示例代码
2020/06/09 Javascript
解决layui table表单提示数据接口请求异常的问题
2019/09/24 Javascript
Vue2.0 实现页面缓存和不缓存的方式
2019/11/12 Javascript
Element Notification通知的实现示例
2020/07/27 Javascript
[05:35]DOTA2英雄梦之声_第13期_拉比克
2014/06/21 DOTA
给Python初学者的一些编程技巧
2015/04/03 Python
在Python的Django框架中生成CSV文件的方法
2015/07/22 Python
简单谈谈Python中函数的可变参数
2016/09/02 Python
Python使用zip合并相邻列表项的方法示例
2018/03/17 Python
Centos 升级到python3后pip 无法使用的解决方法
2018/06/12 Python
浅谈python3中input输入的使用
2019/08/02 Python
python 模拟登陆163邮箱
2020/12/15 Python
什么是动态端口(Dynamic Ports)?动态端口的范围是多少?
2014/12/12 面试题
后勤人员自我鉴定
2013/10/20 职场文书
日语专业个人的求职信
2013/12/03 职场文书
大学生学习2014年全国两会心得体会
2014/03/12 职场文书
大学生求职信范文
2014/05/24 职场文书
2014年征兵标语
2014/06/20 职场文书
党在我心中演讲稿
2014/09/02 职场文书
2015年百日安全活动总结
2015/03/26 职场文书
青年志愿者活动感想
2015/08/07 职场文书
党员干部学习心得体会
2016/01/23 职场文书
Python可视化学习之matplotlib内置单颜色
2022/02/24 Python
MySQL Server层四个日志的实现
2022/03/31 MySQL
铁头也玩根德 YachtBoy YB-230......
2022/04/05 无线电