关于pytorch中网络loss传播和参数更新的理解


Posted in Python onAugust 20, 2019

相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。

TensorFlow: 228--->266

Keras: 42--->56

Pytorch: 87--->252

在使用pytorch中,自己有一些思考,如下:

1. loss计算和反向传播

import torch.nn as nn
 
criterion = nn.MSELoss().cuda()
 
output = model(input)
 
loss = criterion(output, target)
loss.backward()

通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;

最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。

需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。

2. 网络参数更新

在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。

优化算法中,简单的有一阶优化算法:

关于pytorch中网络loss传播和参数更新的理解

其中关于pytorch中网络loss传播和参数更新的理解 就是通常说的学习率,关于pytorch中网络loss传播和参数更新的理解 是函数的梯度;

自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。

在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

注意:

1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。

2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:

optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01) 3) 在优化前需要先将梯度归零,即optimizer.zeros()。

3. loss计算和参数更新

import torch.nn as nn
import torch
 
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
output = model(input)
 
loss = criterion(output, target)
 
​optimizer.zero_grad() # 将所有参数的梯度都置零
loss.backward()    # 误差反向传播计算参数梯度
optimizer.step()    # 通过梯度做一步参数更新

以上这篇关于pytorch中网络loss传播和参数更新的理解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 抓取动态网页内容方案详解
Dec 25 Python
Python实现简单的文件传输与MySQL备份的脚本分享
Jan 03 Python
Python极简代码实现杨辉三角示例代码
Nov 15 Python
pycharm: 恢复(reset) 误删文件的方法
Oct 22 Python
使用Python+wxpy 找出微信里把你删除的好友实例
Feb 21 Python
浅析Python3中的对象垃圾收集机制
Jun 06 Python
python扫描线填充算法详解
Feb 19 Python
python数据类型可变不可变知识点总结
Mar 06 Python
Python 实现使用空值进行赋值 None
Mar 12 Python
关于tf.matmul() 和tf.multiply() 的区别说明
Jun 18 Python
python为什么会环境变量设置不成功
Jun 23 Python
Python 利用flask搭建一个共享服务器的步骤
Dec 05 Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 #Python
pytorch 加载(.pth)格式的模型实例
Aug 20 #Python
python multiprocessing模块用法及原理介绍
Aug 20 #Python
python 并发编程 阻塞IO模型原理解析
Aug 20 #Python
PyTorch中常用的激活函数的方法示例
Aug 20 #Python
You might like
PHP5.2下chunk_split()函数整数溢出漏洞 分析
2007/06/06 PHP
php教程之魔术方法的使用示例(php魔术函数)
2014/02/12 PHP
浅析php中json_encode()和json_decode()
2014/05/25 PHP
ThinkPHP3.1新特性之动态设置自动完成和自动验证示例
2014/06/19 PHP
Yii2框架实现登录、退出及自动登录功能的方法详解
2017/10/24 PHP
PHP序列化的四种实现方法与横向对比
2018/11/29 PHP
PHP实现与java 通信的插件使用教程
2019/08/11 PHP
用 JavaScript 迁移目录
2006/12/18 Javascript
HTA版JSMin(省略修饰语若干)基于javascript语言编写
2009/12/24 Javascript
JavaScript 学习笔记(十六) js事件
2010/02/01 Javascript
js浏览器本地存储store.js介绍及应用
2014/05/13 Javascript
浅谈重写window对象的方法
2014/12/29 Javascript
JavaScript实现将数组数据添加到Select下拉框的方法
2015/08/21 Javascript
微信开发 微信授权详解
2016/10/21 Javascript
Vue+SpringBoot开发V部落博客管理平台
2017/12/27 Javascript
vue-rx的初步使用教程
2018/09/21 Javascript
基于Angular中ng-controller父子级嵌套的相关属性详解
2018/10/08 Javascript
VUE+Element UI实现简单的表格行内编辑效果的示例的代码
2018/10/31 Javascript
Vue 页面状态保持页面间数据传输的一种方法(推荐)
2018/11/01 Javascript
实例讲解v-if和v-show的区别
2019/01/31 Javascript
Node.js系列之安装配置与基本使用(1)
2019/08/30 Javascript
javascript操作元素的常见方法小结
2019/11/13 Javascript
在webstorm中配置less的方法详解
2020/09/25 Javascript
Python函数式编程指南(一):函数式编程概述
2015/06/24 Python
Python实现向服务器请求压缩数据及解压缩数据的方法示例
2017/06/09 Python
Python学习笔记之For循环用法详解
2019/08/14 Python
Python接口测试数据库封装实现原理
2020/05/09 Python
html5 input输入实时检测以及延时优化
2018/07/18 HTML / CSS
有机童装:Toby Tiger
2018/05/23 全球购物
laravel使用redis队列实例讲解
2021/03/23 PHP
视光学专业毕业生推荐信
2013/10/28 职场文书
告诉你怎样写创业计划书
2014/01/27 职场文书
幼儿园小班评语大全
2014/04/17 职场文书
2017春节晚会开幕词
2016/03/03 职场文书
多人股份制合作协议书
2016/03/19 职场文书
python第三方网页解析器 lxml 扩展库与 xpath 的使用方法
2021/04/06 Python