关于pytorch中网络loss传播和参数更新的理解


Posted in Python onAugust 20, 2019

相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。

TensorFlow: 228--->266

Keras: 42--->56

Pytorch: 87--->252

在使用pytorch中,自己有一些思考,如下:

1. loss计算和反向传播

import torch.nn as nn
 
criterion = nn.MSELoss().cuda()
 
output = model(input)
 
loss = criterion(output, target)
loss.backward()

通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;

最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。

需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。

2. 网络参数更新

在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。

优化算法中,简单的有一阶优化算法:

关于pytorch中网络loss传播和参数更新的理解

其中关于pytorch中网络loss传播和参数更新的理解 就是通常说的学习率,关于pytorch中网络loss传播和参数更新的理解 是函数的梯度;

自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。

在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

注意:

1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。

2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:

optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01) 3) 在优化前需要先将梯度归零,即optimizer.zeros()。

3. loss计算和参数更新

import torch.nn as nn
import torch
 
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
output = model(input)
 
loss = criterion(output, target)
 
​optimizer.zero_grad() # 将所有参数的梯度都置零
loss.backward()    # 误差反向传播计算参数梯度
optimizer.step()    # 通过梯度做一步参数更新

以上这篇关于pytorch中网络loss传播和参数更新的理解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python简单删除目录下文件以及文件夹的方法
May 27 Python
python2.7和NLTK安装详细教程
Sep 19 Python
python使用Matplotlib画条形图
Mar 25 Python
Python+OpenCV图片局部区域像素值处理详解
Jan 23 Python
Python 运行.py文件和交互式运行代码的区别详解
Jul 02 Python
python 函数的缺省参数使用注意事项分析
Sep 17 Python
python文字转语音实现过程解析
Nov 12 Python
Python打包工具PyInstaller的安装与pycharm配置支持PyInstaller详细方法
Feb 27 Python
python实现用户名密码校验
Mar 18 Python
使用pth文件添加Python环境变量方式
May 26 Python
如何在python中处理配置文件代码实例
Sep 27 Python
Python实现Excel自动分组合并单元格
Feb 22 Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 #Python
pytorch 加载(.pth)格式的模型实例
Aug 20 #Python
python multiprocessing模块用法及原理介绍
Aug 20 #Python
python 并发编程 阻塞IO模型原理解析
Aug 20 #Python
PyTorch中常用的激活函数的方法示例
Aug 20 #Python
You might like
一个php作的文本留言本的例子(二)
2006/10/09 PHP
PHP获取youku视频真实flv文件地址的方法
2014/12/23 PHP
php利用smtp类实现电子邮件发送
2015/10/30 PHP
WordPress中调试缩略图的相关PHP函数使用解析
2016/01/07 PHP
javascript延时重复执行函数 lLoopRun.js
2007/06/29 Javascript
jQuery选择器简明总结(含用法实例,一目了然)
2014/04/25 Javascript
JavaScript strike方法入门实例(给字符串加上删除线)
2014/10/17 Javascript
JavaScript和HTML DOM的区别与联系及Javascript和DOM的关系
2015/11/15 Javascript
js 点击a标签 获取a的自定义属性方法
2016/11/21 Javascript
js实现HashTable(哈希表)的实例分析
2016/11/21 Javascript
Angular2 Service实现简单音乐播放器服务
2017/02/24 Javascript
JS闭包用法实例分析
2017/03/27 Javascript
jQuery中的deferred对象和extend方法详解
2017/05/08 jQuery
详细AngularJs4的图片剪裁组件的实例
2017/07/12 Javascript
Vue2.0基于vue-cli+webpack父子组件通信(实例讲解)
2017/09/14 Javascript
AngularJs ng-change事件/指令的用法小结
2017/11/01 Javascript
Vue-cli@3.0 插件系统简析
2018/09/05 Javascript
vue router 源码概览案例分析
2018/10/09 Javascript
Vue中 v-if/v-show/插值表达式导致闪现的原因及解决办法
2018/10/12 Javascript
基于js实现复制内容到操作系统粘贴板过程解析
2019/10/11 Javascript
茶余饭后聊聊Vue3.0响应式数据那些事儿
2019/10/30 Javascript
vue中的v-model原理,与组件自定义v-model详解
2020/08/04 Javascript
django文档学习之applications使用详解
2018/01/29 Python
用TensorFlow实现戴明回归算法的示例
2018/05/02 Python
对Python中list的倒序索引和切片实例讲解
2018/11/15 Python
Python中GeoJson和bokeh-1的使用讲解
2019/01/03 Python
利用pyshp包给shapefile文件添加字段的实例
2019/12/06 Python
TensorFlow的环境配置与安装方法
2021/02/20 Python
Everlast官网:拳击、综合格斗和健身相关的体育用品
2020/08/03 全球购物
摄影专业毕业生求职信
2014/03/13 职场文书
青蓝工程实施方案
2014/03/27 职场文书
乡文化站暑期培训方案
2014/08/28 职场文书
离婚协议书范文2014(夫妻感情破裂)
2014/12/14 职场文书
邀请函模板
2015/02/02 职场文书
创业计划书详解
2019/07/19 职场文书
2019送给家人们的中秋节祝福语
2019/08/15 职场文书