关于pytorch中网络loss传播和参数更新的理解


Posted in Python onAugust 20, 2019

相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。

TensorFlow: 228--->266

Keras: 42--->56

Pytorch: 87--->252

在使用pytorch中,自己有一些思考,如下:

1. loss计算和反向传播

import torch.nn as nn
 
criterion = nn.MSELoss().cuda()
 
output = model(input)
 
loss = criterion(output, target)
loss.backward()

通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;

最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。

需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。

2. 网络参数更新

在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。

优化算法中,简单的有一阶优化算法:

关于pytorch中网络loss传播和参数更新的理解

其中关于pytorch中网络loss传播和参数更新的理解 就是通常说的学习率,关于pytorch中网络loss传播和参数更新的理解 是函数的梯度;

自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。

在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

注意:

1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。

2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:

optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01) 3) 在优化前需要先将梯度归零,即optimizer.zeros()。

3. loss计算和参数更新

import torch.nn as nn
import torch
 
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
output = model(input)
 
loss = criterion(output, target)
 
​optimizer.zero_grad() # 将所有参数的梯度都置零
loss.backward()    # 误差反向传播计算参数梯度
optimizer.step()    # 通过梯度做一步参数更新

以上这篇关于pytorch中网络loss传播和参数更新的理解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python发腾讯微博代码分享
Jan 10 Python
在Python中封装GObject模块进行图形化程序编程的教程
Apr 14 Python
python 远程统计文件代码分享
May 14 Python
Python2.7基于淘宝接口获取IP地址所在地理位置的方法【测试可用】
Jun 07 Python
python+VTK环境搭建及第一个简单程序代码
Dec 13 Python
Python和Java进行DES加密和解密的实例
Jan 09 Python
TensorFlow高效读取数据的方法示例
Feb 06 Python
对Python3+gdal 读取tiff格式数据的实例讲解
Dec 04 Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 Python
Python如何使用OS模块调用cmd
Feb 27 Python
利用 Python ElementTree 生成 xml的实例
Mar 06 Python
python将图片转为矢量图的方法步骤
Mar 30 Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 #Python
pytorch 加载(.pth)格式的模型实例
Aug 20 #Python
python multiprocessing模块用法及原理介绍
Aug 20 #Python
python 并发编程 阻塞IO模型原理解析
Aug 20 #Python
PyTorch中常用的激活函数的方法示例
Aug 20 #Python
You might like
php HandlerSocket的使用
2011/05/02 PHP
php判断ip黑名单程序代码实例
2014/02/24 PHP
PHP使用Memcache时模拟命名空间及缓存失效问题的解决
2016/02/27 PHP
Thinkphp 框架扩展之Widget扩展实现方法分析
2020/04/23 PHP
在textarea中显示html页面的javascript代码
2007/04/20 Javascript
JSON 学习之JSON in JavaScript详细使用说明
2010/02/23 Javascript
Jquery仿淘宝京东多条件筛选可自行结合ajax加载示例
2013/08/28 Javascript
简单实用的反馈表单无刷新提交带验证
2013/11/15 Javascript
删除javascript中注释语句的正则表达式
2014/06/11 Javascript
Bootstrap每天必学之面板
2015/11/30 Javascript
jquery实现具有收缩功能的垂直导航菜单
2016/02/16 Javascript
Angular中的interceptors拦截器
2017/06/25 Javascript
AngularJs每天学习之总体介绍
2017/08/07 Javascript
JavaScript实现的前端AES加密解密功能【基于CryptoJS】
2018/08/28 Javascript
微信小程序下拉框功能的实例代码
2018/11/06 Javascript
Vue使用lodop实现打印小结
2019/07/06 Javascript
解决Vue使用bus总线时,第一次路由跳转时数据没成功传递问题
2020/07/28 Javascript
python读写json文件的简单实现
2017/04/11 Python
Python探索之创建二叉树
2017/10/25 Python
python实现定时自动备份文件到其他主机的实例代码
2018/02/23 Python
python爬虫框架scrapy实现模拟登录操作示例
2018/08/02 Python
python将数组n等分的实例
2019/12/02 Python
python加载自定义词典实例
2019/12/06 Python
Django通用类视图实现忘记密码重置密码功能示例
2019/12/17 Python
python selenium xpath定位操作
2020/09/01 Python
使用python操作lmdb对数据读取的实例
2020/12/11 Python
用CSS3打造HTML5的Logo(实现代码)
2016/06/16 HTML / CSS
css3翻牌翻数字的示例代码
2020/02/07 HTML / CSS
达拉斯牛仔官方商店:Dallas Cowboys Pro Shop
2018/02/10 全球购物
美国社交购物市场:MassGenie
2019/02/18 全球购物
自主招生自荐信格式
2013/12/03 职场文书
电大本科自我鉴定
2014/02/05 职场文书
现场施工员岗位职责
2014/03/10 职场文书
2014年防汛工作总结
2014/12/08 职场文书
骨干教师事迹材料
2014/12/17 职场文书
导游词之吉林花园山
2019/10/17 职场文书