详解PyTorch批训练及优化器比较


Posted in Python onApril 28, 2018

一、PyTorch批训练

1. 概述

PyTorch提供了一种将数据包装起来进行批训练的工具——DataLoader。使用的时候,只需要将我们的数据首先转换为torch的tensor形式,再转换成torch可以识别的Dataset格式,然后将Dataset放入DataLoader中就可以啦。

import torch 
import torch.utils.data as Data 
 
torch.manual_seed(1) # 设定随机数种子 
 
BATCH_SIZE = 5 
 
x = torch.linspace(1, 10, 10) 
y = torch.linspace(0.5, 5, 10) 
 
# 将数据转换为torch的dataset格式 
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y) 
 
# 将torch_dataset置入Dataloader中 
loader = Data.DataLoader( 
  dataset=torch_dataset, 
  batch_size=BATCH_SIZE, # 批大小 
  # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少 
  shuffle=True, # 是否随机打乱顺序 
  num_workers=2, # 多线程读取数据的线程数 
  ) 
 
for epoch in range(3): 
  for step, (batch_x, batch_y) in enumerate(loader): 
    print('Epoch:', epoch, '|Step:', step, '|batch_x:', 
       batch_x.numpy(), '|batch_y', batch_y.numpy()) 
''''' 
shuffle=True 
Epoch: 0 |Step: 0 |batch_x: [ 6. 7. 2. 3. 1.] |batch_y [ 3.  3.5 1.  1.5 0.5] 
Epoch: 0 |Step: 1 |batch_x: [ 9. 10.  4.  8.  5.] |batch_y [ 4.5 5.  2.  4.  2.5] 
Epoch: 1 |Step: 0 |batch_x: [ 3.  4.  2.  9. 10.] |batch_y [ 1.5 2.  1.  4.5 5. ] 
Epoch: 1 |Step: 1 |batch_x: [ 1. 7. 8. 5. 6.] |batch_y [ 0.5 3.5 4.  2.5 3. ] 
Epoch: 2 |Step: 0 |batch_x: [ 3. 9. 2. 6. 7.] |batch_y [ 1.5 4.5 1.  3.  3.5] 
Epoch: 2 |Step: 1 |batch_x: [ 10.  4.  8.  1.  5.] |batch_y [ 5.  2.  4.  0.5 2.5] 
 
shuffle=False 
Epoch: 0 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1.  1.5 2.  2.5] 
Epoch: 0 |Step: 1 |batch_x: [ 6.  7.  8.  9. 10.] |batch_y [ 3.  3.5 4.  4.5 5. ] 
Epoch: 1 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1.  1.5 2.  2.5] 
Epoch: 1 |Step: 1 |batch_x: [ 6.  7.  8.  9. 10.] |batch_y [ 3.  3.5 4.  4.5 5. ] 
Epoch: 2 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1.  1.5 2.  2.5] 
Epoch: 2 |Step: 1 |batch_x: [ 6.  7.  8.  9. 10.] |batch_y [ 3.  3.5 4.  4.5 5. ] 
'''

2. TensorDataset

classtorch.utils.data.TensorDataset(data_tensor, target_tensor)

TensorDataset类用来将样本及其标签打包成torch的Dataset,data_tensor,和target_tensor都是tensor。

3. DataLoader

classtorch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,num_workers=0, collate_fn=<function default_collate>, pin_memory=False,drop_last=False)

dataset就是Torch的Dataset格式的对象;batch_size即每批训练的样本数量,默认为;shuffle表示是否需要随机取样本;num_workers表示读取样本的线程数。

二、PyTorch的Optimizer优化器

本实验中,首先构造一组数据集,转换格式并置于DataLoader中,备用。定义一个固定结构的默认神经网络,然后为每个优化器构建一个神经网络,每个神经网络的区别仅仅是优化器不同。通过记录训练过程中的loss值,最后在图像上呈现得到各个优化器的优化过程。

代码实现:

import torch 
import torch.utils.data as Data 
import torch.nn.functional as F 
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
torch.manual_seed(1) # 设定随机数种子 
 
# 定义超参数 
LR = 0.01 # 学习率 
BATCH_SIZE = 32 # 批大小 
EPOCH = 12 # 迭代次数 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1) 
y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size())) 
 
#plt.scatter(x.numpy(), y.numpy()) 
#plt.show() 
 
# 将数据转换为torch的dataset格式 
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y) 
# 将torch_dataset置入Dataloader中 
loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, 
             shuffle=True, num_workers=2) 
 
class Net(torch.nn.Module): 
  def __init__(self): 
    super(Net, self).__init__() 
    self.hidden = torch.nn.Linear(1, 20) 
    self.predict = torch.nn.Linear(20, 1) 
 
  def forward(self, x): 
    x = F.relu(self.hidden(x)) 
    x = self.predict(x) 
    return x 
 
# 为每个优化器创建一个Net 
net_SGD = Net() 
net_Momentum = Net() 
net_RMSprop = Net() 
net_Adam = Net()  
nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam] 
 
# 初始化优化器 
opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR) 
opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8) 
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9) 
opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99)) 
 
optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam] 
 
# 定义损失函数 
loss_function = torch.nn.MSELoss() 
losses_history = [[], [], [], []] # 记录training时不同神经网络的loss值 
 
for epoch in range(EPOCH): 
  print('Epoch:', epoch + 1, 'Training...') 
  for step, (batch_x, batch_y) in enumerate(loader): 
    b_x = Variable(batch_x) 
    b_y = Variable(batch_y) 
 
    for net, opt, l_his in zip(nets, optimizers, losses_history): 
      output = net(b_x) 
      loss = loss_function(output, b_y) 
      opt.zero_grad() 
      loss.backward() 
      opt.step() 
      l_his.append(loss.data[0]) 
 
labels = ['SGD', 'Momentum', 'RMSprop', 'Adam'] 
 
for i, l_his in enumerate(losses_history): 
  plt.plot(l_his, label=labels[i]) 
plt.legend(loc='best') 
plt.xlabel('Steps') 
plt.ylabel('Loss') 
plt.ylim((0, 0.2)) 
plt.show()

实验结果:

详解PyTorch批训练及优化器比较

由实验结果可见,SGD的优化效果是最差的,速度很慢;作为SGD的改良版本,Momentum表现就好许多;相比RMSprop和Adam的优化速度就非常好。实验中,针对不同的优化问题,比较各个优化器的效果再来决定使用哪个。

三、其他补充

1. Python的zip函数

zip函数接受任意多个(包括0个和1个)序列作为参数,返回一个tuple列表。

x = [1, 2, 3] 
y = [4, 5, 6] 
z = [7, 8, 9] 
xyz = zip(x, y, z) 
print xyz 
[(1, 4, 7), (2, 5, 8), (3, 6, 9)] 
 
x = [1, 2, 3] 
x = zip(x) 
print x 
[(1,), (2,), (3,)] 
 
x = [1, 2, 3] 
y = [4, 5, 6, 7] 
xy = zip(x, y) 
print xy 
[(1, 4), (2, 5), (3, 6)]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3.5仿微软计算器程序
Mar 30 Python
Python实现Windows和Linux之间互相传输文件(文件夹)的方法
May 08 Python
解决出现Incorrect integer value: '' for column 'id' at row 1的问题
Oct 29 Python
Python闭包之返回函数的函数用法示例
Jan 27 Python
Python进度条实时显示处理进度的示例代码
Jan 30 Python
Python 读取图片文件为矩阵和保存矩阵为图片的方法
Apr 27 Python
Python3中的json模块使用详解
May 05 Python
Django框架模板语言实例小结【变量,标签,过滤器,继承,html转义】
May 23 Python
人工神经网络算法知识点总结
Jun 11 Python
使用tensorflow实现矩阵分解方式
Feb 07 Python
python使用glob检索文件的操作
May 20 Python
python中取整数的几种方法
Nov 07 Python
Python使用matplotlib实现的图像读取、切割裁剪功能示例
Apr 28 #Python
浅谈python日志的配置文件路径问题
Apr 28 #Python
PyTorch上实现卷积神经网络CNN的方法
Apr 28 #Python
python 日志增量抓取实现方法
Apr 28 #Python
Django 使用logging打印日志的实例
Apr 28 #Python
python实现log日志的示例代码
Apr 28 #Python
Python学习笔记之open()函数打开文件路径报错问题
Apr 28 #Python
You might like
PHP脚本的10个技巧(5)
2006/10/09 PHP
基于php-fpm 参数的深入理解
2013/06/03 PHP
解决PHP超大文件下载,断点续传下载的方法详解
2013/06/06 PHP
JavaScript创建命名空间的5种写法
2014/06/24 PHP
JavaScript实现删除电脑的关机键
2016/07/26 PHP
PHP编程快速实现数组去重的方法详解
2017/07/22 PHP
表单提交验证类
2006/07/14 Javascript
比较简单的一个符合web标准的JS调用flash方法
2007/11/29 Javascript
js 数值项目的格式化函数代码
2010/05/14 Javascript
JQuery+Ajax实现数据查询、排序和分页功能
2015/09/27 Javascript
如何解决ligerUI布局时Center中的Tab高度大小
2015/11/24 Javascript
AngularJS模块学习之Anchor Scroll
2016/01/19 Javascript
JS中的eval 为什么加括号
2016/04/13 Javascript
如何用JS判断两个数字的大小
2016/07/21 Javascript
Easyui的组合框的取值与赋值
2016/10/28 Javascript
Vue.js事件处理器与表单控件绑定详解
2017/03/20 Javascript
angular2 ng2-file-upload上传示例代码
2018/08/23 Javascript
Vue.js实现可编辑的表格
2019/12/11 Javascript
js中!和!!的区别与用法
2020/05/09 Javascript
vue 防止页面加载时看到花括号的解决操作
2020/11/09 Javascript
Vue 3.0中jsx语法的使用
2020/11/13 Javascript
python获取远程图片大小和尺寸的方法
2015/03/26 Python
python控制windows剪贴板,向剪贴板中写入图片的实例
2018/05/31 Python
python3中zip()函数使用详解
2018/06/29 Python
解决Python设置函数调用超时,进程卡住的问题
2019/08/08 Python
玩转CSS3色彩
2010/01/16 HTML / CSS
HTML5 video播放器全屏(fullScreen)方法实例
2015/04/24 HTML / CSS
HTML5 canvas基本绘图之图形组合
2016/06/27 HTML / CSS
Hoover胡佛官网:美国吸尘器和洗地机品牌
2019/01/09 全球购物
追悼会上的答谢词
2014/01/10 职场文书
两年的个人工作自我评价
2014/01/10 职场文书
销售员个人求职的自我评价
2014/02/10 职场文书
小学教师师德整改措施
2014/09/29 职场文书
教师党的群众路线教育实践活动剖析材料
2014/10/09 职场文书
VUE之图片Base64编码使用ElementUI组件上传
2022/04/09 Vue.js
优化Mysql查询的示例
2022/04/26 MySQL