pytorch 图像中的数据预处理和批标准化实例


Posted in Python onJanuary 15, 2020

目前数据预处理最常见的方法就是中心化和标准化。

中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。

标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间

批标准化:BN

在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。

所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。

batch normalization 的实现非常简单,接下来写一下对应的python代码:

import sys
sys.path.append('..')
 
import torch
 
def simple_batch_norm_1d(x, gamma, beta):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
   
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)

测试的时候该使用批标准化吗?

答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替

下面我们实现以下能够区分训练状态和测试状态的批标准化方法

def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  if is_training:
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
    moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
  else:
    x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

下面我们在卷积网络下试用一下批标准化看看效果

def data_tf(x):
  x = np.array(x, dtype='float32') / 255
  x = (x - 0.5) / 0.5 # 数据预处理,标准化
  x = torch.from_numpy(x)
  x = x.unsqueeze(0)
  return x
 
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
  def __init__(self):
    super(conv_bn_net, self).__init__()
    self.stage1 = nn.Sequential(
      nn.Conv2d(1, 6, 3, padding=1),
      nn.BatchNorm2d(6),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2),
      nn.Conv2d(6, 16, 5),
      nn.BatchNorm2d(16),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2)
    )
    
    self.classfy = nn.Linear(400, 10)
  def forward(self, x):
    x = self.stage1(x)
    x = x.view(x.shape[0], -1)
    x = self.classfy(x)
    return x
 
net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1
 
 
train(net, train_data, test_data, 5, optimizer, criterion)

以上这篇pytorch 图像中的数据预处理和批标准化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中字典的基本知识初步介绍
May 21 Python
Django教程笔记之中间件middleware详解
Aug 01 Python
Python动态赋值的陷阱知识点总结
Mar 17 Python
简单了解Python生成器是什么
Jul 02 Python
Python企业编码生成系统总体系统设计概述
Jul 26 Python
python利用tkinter实现屏保
Jul 30 Python
Django框架HttpResponse对象用法实例分析
Nov 01 Python
python 使用递归实现打印一个数字的每一位示例
Feb 27 Python
python脚本实现mp4中的音频提取并保存在原目录
Feb 27 Python
Python多线程threading创建及使用方法解析
Jun 17 Python
基于CentOS搭建Python Django环境过程解析
Aug 24 Python
Python的flask接收前台的ajax的post数据和get数据的方法
Apr 12 Python
pytorch实现特殊的Module--Sqeuential三种写法
Jan 15 #Python
python实现删除列表中某个元素的3种方法
Jan 15 #Python
python opencv根据颜色进行目标检测的方法示例
Jan 15 #Python
Python基于Tensor FLow的图像处理操作详解
Jan 15 #Python
OpenCV哈里斯(Harris)角点检测的实现
Jan 15 #Python
Pytorch模型转onnx模型实例
Jan 15 #Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
Jan 15 #Python
You might like
受疫情影响 动画《Re从零开始的异世界生活》第二季延期至7月
2020/03/10 日漫
php csv操作类代码
2009/12/14 PHP
使用GROUP BY的时候如何统计记录条数 COUNT(*) DISTINCT
2011/04/23 PHP
PHP+Mysql+jQuery查询和列表框选择操作实例讲解
2015/10/22 PHP
CodeIgniter配置之autoload.php自动加载用法分析
2016/01/20 PHP
PHP实现搜索时记住状态的方法示例
2018/05/11 PHP
Windows Live的@live.com域名注册漏洞 利用代码
2006/12/27 Javascript
Extjs中常用表单介绍与应用
2010/06/07 Javascript
jQuery Tools tab使用介绍
2012/07/14 Javascript
JS自定义对象实现Java中Map对象功能的方法
2015/01/20 Javascript
Jquery简单分页实现方法
2015/07/24 Javascript
js实现人才网站职位选择功能的方法
2015/08/14 Javascript
jquery获取复选框checkbox的值的简单实现方法
2016/05/26 Javascript
js实现做通讯录的索引滑动显示效果和滑动显示锚点效果
2017/02/18 Javascript
JS小球抛物线轨迹运动的两种实现方法详解
2017/12/20 Javascript
用Python进行基础的函数式编程的教程
2015/03/31 Python
Fiddler如何抓取手机APP数据包
2016/01/22 Python
Python 的描述符 descriptor详解
2016/02/27 Python
python 中如何获取列表的索引
2019/07/02 Python
Django框架视图层URL映射与反向解析实例分析
2019/07/29 Python
Python 实现微信自动回复的方法
2020/09/11 Python
纯CSS3实现滚动的齿轮动画效果
2014/06/05 HTML / CSS
HTML5 script元素async、defer异步加载使用介绍
2013/08/23 HTML / CSS
戴尔英国官网:Dell英国
2017/05/27 全球购物
美国优质宠物用品购买网站:Muttropolis
2020/02/17 全球购物
酒店led欢迎词
2014/01/09 职场文书
境外导游求职信
2014/02/27 职场文书
信用社主任竞聘演讲稿
2014/05/23 职场文书
优秀中职教师事迹材料
2014/08/26 职场文书
高中生学习计划书
2014/09/15 职场文书
领导干部群众路线教育实践活动剖析材料
2014/10/10 职场文书
2014年世界艾滋病日演讲稿
2014/11/28 职场文书
初中教师德育工作总结2015
2015/05/12 职场文书
歼十出击观后感
2015/06/11 职场文书
React如何创建组件
2021/06/27 Javascript
Python机器学习应用之基于线性判别模型的分类篇详解
2022/01/18 Python