关于pytorch中网络loss传播和参数更新的理解


Posted in Python onAugust 20, 2019

相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。

TensorFlow: 228--->266

Keras: 42--->56

Pytorch: 87--->252

在使用pytorch中,自己有一些思考,如下:

1. loss计算和反向传播

import torch.nn as nn
 
criterion = nn.MSELoss().cuda()
 
output = model(input)
 
loss = criterion(output, target)
loss.backward()

通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;

最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。

需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。

2. 网络参数更新

在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。

优化算法中,简单的有一阶优化算法:

关于pytorch中网络loss传播和参数更新的理解

其中关于pytorch中网络loss传播和参数更新的理解 就是通常说的学习率,关于pytorch中网络loss传播和参数更新的理解 是函数的梯度;

自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。

在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

注意:

1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。

2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:

optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01) 3) 在优化前需要先将梯度归零,即optimizer.zeros()。

3. loss计算和参数更新

import torch.nn as nn
import torch
 
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
output = model(input)
 
loss = criterion(output, target)
 
​optimizer.zero_grad() # 将所有参数的梯度都置零
loss.backward()    # 误差反向传播计算参数梯度
optimizer.step()    # 通过梯度做一步参数更新

以上这篇关于pytorch中网络loss传播和参数更新的理解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
初学python数组的处理代码
Jan 04 Python
深度剖析使用python抓取网页正文的源码
Jun 11 Python
python批量修改文件名的实现代码
Sep 01 Python
Python THREADING模块中的JOIN()方法深入理解
Feb 18 Python
关于Python中浮点数精度处理的技巧总结
Aug 10 Python
Python实现压缩和解压缩ZIP文件的方法分析
Sep 28 Python
对pandas里的loc并列条件索引的实例讲解
Nov 15 Python
windows下搭建python scrapy爬虫框架步骤
Dec 23 Python
python 动态生成变量名以及动态获取变量的变量名方法
Jan 20 Python
python 默认参数相关知识详解
Sep 18 Python
python 中的paramiko模块简介及安装过程
Feb 29 Python
python多线程爬取西刺代理的示例代码
Jan 30 Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 #Python
pytorch 加载(.pth)格式的模型实例
Aug 20 #Python
python multiprocessing模块用法及原理介绍
Aug 20 #Python
python 并发编程 阻塞IO模型原理解析
Aug 20 #Python
PyTorch中常用的激活函数的方法示例
Aug 20 #Python
You might like
PHP中simplexml_load_string函数使用说明
2011/01/01 PHP
php循环语句 for()与foreach()用法区别介绍
2012/09/05 PHP
Yii2单元测试用法示例
2016/11/12 PHP
基于jQuery的动态增删改查表格信息,可左键/右键提示(原创自Zjmainstay)
2012/07/31 Javascript
JavaScript对象和字串之间的转换实例探讨
2013/04/21 Javascript
JQuery对表单元素的基本操作使用总结
2014/07/18 Javascript
javascript数组遍历for与for in区别详解
2014/12/04 Javascript
javascript中DOM复选框选择用法实例
2015/05/14 Javascript
js实现数组转换成json
2015/06/26 Javascript
JavaScript的Polymer框架中dom-repeat与VM的相关操作
2015/07/29 Javascript
为jQuery-easyui的tab组件添加右键菜单功能的简单实例
2016/10/10 Javascript
JS实现针对给定时间的倒计时功能示例
2017/04/11 Javascript
Node.js如何实现注册邮箱激活功能 (常见)
2017/07/23 Javascript
详解小程序如何动态绑定点击的执行方法
2019/11/26 Javascript
vue-router的hooks用法详解
2020/06/08 Javascript
Vue.js使用axios动态获取response里的data数据操作
2020/09/08 Javascript
Vertx基于EventBus发送接受自定义对象
2020/11/16 Javascript
[00:34]DOTA2上海特级锦标赛 VG战队宣传片
2016/03/04 DOTA
[01:00:22]DOTA2-DPC中国联赛定级赛 LBZS vs Magma BO3第三场 1月10日
2021/03/11 DOTA
Python enumerate遍历数组示例应用
2008/09/06 Python
python3实现ftp服务功能(客户端)
2017/03/24 Python
Django中数据库的数据关系:一对一,一对多,多对多
2018/10/21 Python
浅谈pyqt5中信号与槽的认识
2019/02/17 Python
Python中dict和set的用法讲解
2019/03/28 Python
Python列表删除元素del、pop()和remove()的区别小结
2019/09/11 Python
python线程信号量semaphore使用解析
2019/11/30 Python
解决python-docx打包之后找不到default.docx的问题
2020/02/13 Python
CSS3色彩模式有哪些?CSS3 HSL色彩模式的定义
2016/04/26 HTML / CSS
HTML5中drawImage用法分析
2014/12/01 HTML / CSS
英国女士家居服网站:hush
2017/08/09 全球购物
俄罗斯美容和健康网上商店:Созвездие Красоты
2019/07/23 全球购物
好的自荐信包括什么内容
2013/11/07 职场文书
机械设计及其自动化专业求职信
2014/06/09 职场文书
领导工作表现评语
2015/01/04 职场文书
使用 Apache 反向代理的设置技巧
2022/01/18 Servers
Vue组件化(ref,props, mixin,.插件)详解
2022/05/15 Vue.js