关于pytorch中网络loss传播和参数更新的理解


Posted in Python onAugust 20, 2019

相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。

TensorFlow: 228--->266

Keras: 42--->56

Pytorch: 87--->252

在使用pytorch中,自己有一些思考,如下:

1. loss计算和反向传播

import torch.nn as nn
 
criterion = nn.MSELoss().cuda()
 
output = model(input)
 
loss = criterion(output, target)
loss.backward()

通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;

最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。

需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。

2. 网络参数更新

在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。

优化算法中,简单的有一阶优化算法:

关于pytorch中网络loss传播和参数更新的理解

其中关于pytorch中网络loss传播和参数更新的理解 就是通常说的学习率,关于pytorch中网络loss传播和参数更新的理解 是函数的梯度;

自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。

在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

注意:

1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。

2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:

optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01) 3) 在优化前需要先将梯度归零,即optimizer.zeros()。

3. loss计算和参数更新

import torch.nn as nn
import torch
 
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
output = model(input)
 
loss = criterion(output, target)
 
​optimizer.zero_grad() # 将所有参数的梯度都置零
loss.backward()    # 误差反向传播计算参数梯度
optimizer.step()    # 通过梯度做一步参数更新

以上这篇关于pytorch中网络loss传播和参数更新的理解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用gensim计算文档相似性
Apr 10 Python
Python 爬虫学习笔记之多线程爬虫
Sep 21 Python
使用python实现ANN
Dec 20 Python
python 对dataframe下面的值进行大规模赋值方法
Jun 09 Python
python爬虫超时的处理的实例
Dec 19 Python
python re库的正则表达式入门学习教程
Mar 08 Python
Pycharm远程调试原理及具体配置详解
Aug 08 Python
python__name__原理及用法详解
Nov 02 Python
python对Excel的读取的示例代码
Feb 14 Python
python变量的作用域是什么
May 26 Python
Jupyter Notebook内使用argparse报错的解决方案
Jun 03 Python
Python3中最常用的5种线程锁实例总结
Jul 07 Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 #Python
pytorch 加载(.pth)格式的模型实例
Aug 20 #Python
python multiprocessing模块用法及原理介绍
Aug 20 #Python
python 并发编程 阻塞IO模型原理解析
Aug 20 #Python
PyTorch中常用的激活函数的方法示例
Aug 20 #Python
You might like
PHP实现MySQL更新记录的代码
2008/06/07 PHP
PHP取二进制文件头快速判断文件类型的实现代码
2013/08/05 PHP
PHP的Yii框架中移除组件所绑定的行为的方法
2016/03/18 PHP
PHP PDOStatement::bindValue讲解
2019/01/30 PHP
js前台判断开始时间是否小于结束时间
2012/02/23 Javascript
Ajax搜索结果页面下方的分页按钮的生成
2012/04/05 Javascript
jQuery拖拽 & 弹出层 介绍与示例
2013/12/27 Javascript
JavaScript function 的 length 属性使用介绍
2014/09/15 Javascript
jquery插件hiAlert实现网页对话框美化
2015/05/03 Javascript
jQuery+jsp实现省市县三级联动效果(附源码)
2015/12/03 Javascript
基于Node.js的强大爬虫 能直接发布抓取的文章哦
2016/01/10 Javascript
PassWord输入框代码分享
2016/06/07 Javascript
nodejs入门教程五:连接数据库的方法分析
2017/04/24 NodeJs
详解nodejs实现本地上传图片并预览功能(express4.0+)
2017/06/28 NodeJs
基于vue-video-player自定义播放器的方法
2018/03/21 Javascript
Layui动态生成select下拉选择框不显示的解决方法
2019/09/24 Javascript
vue.config.js常用配置详解
2019/11/14 Javascript
vue动态设置页面title的方法实例
2020/08/23 Javascript
vue 导航守卫和axios拦截器有哪些区别
2020/12/19 Vue.js
在HTML中使用JavaScript的两种方法
2020/12/24 Javascript
[01:03:22]LGD vs OG 2018国际邀请赛淘汰赛BO3 第一场 8.25
2018/08/29 DOTA
python访问sqlserver示例
2014/02/10 Python
python 用正则表达式筛选文本信息的实例
2018/06/05 Python
Python爬虫实现验证码登录代码实例
2019/05/10 Python
Python3多目标赋值及共享引用注意事项
2019/05/27 Python
Python实现网页截图(PyQT5)过程解析
2019/08/12 Python
python中web框架的自定义创建
2019/09/08 Python
马克华菲官方商城:Mark Fairwhale
2016/09/04 全球购物
雅诗兰黛(Estee Lauder)英国官方网站:世界顶级化妆品牌
2016/12/29 全球购物
高中生校园生活自我评价
2013/09/19 职场文书
医学生自荐信
2013/12/03 职场文书
开办大学饮食联盟创业计划书
2014/01/29 职场文书
高中教师个人工作总结
2015/02/10 职场文书
openstack中的rpc远程调用的方法
2021/07/09 Python
python中mongodb包操作数据库
2022/04/19 Python
微软官方消息,在 2023 年 4 月 11 日之后微软将不再为 Office 2013 和 Skype for Business 2015 提供安全更新
2022/04/21 数码科技