关于pytorch中网络loss传播和参数更新的理解


Posted in Python onAugust 20, 2019

相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。

TensorFlow: 228--->266

Keras: 42--->56

Pytorch: 87--->252

在使用pytorch中,自己有一些思考,如下:

1. loss计算和反向传播

import torch.nn as nn
 
criterion = nn.MSELoss().cuda()
 
output = model(input)
 
loss = criterion(output, target)
loss.backward()

通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;

最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。

需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。

2. 网络参数更新

在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。

优化算法中,简单的有一阶优化算法:

关于pytorch中网络loss传播和参数更新的理解

其中关于pytorch中网络loss传播和参数更新的理解 就是通常说的学习率,关于pytorch中网络loss传播和参数更新的理解 是函数的梯度;

自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。

在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

注意:

1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。

2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:

optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01) 3) 在优化前需要先将梯度归零,即optimizer.zeros()。

3. loss计算和参数更新

import torch.nn as nn
import torch
 
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
output = model(input)
 
loss = criterion(output, target)
 
​optimizer.zero_grad() # 将所有参数的梯度都置零
loss.backward()    # 误差反向传播计算参数梯度
optimizer.step()    # 通过梯度做一步参数更新

以上这篇关于pytorch中网络loss传播和参数更新的理解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python sort、sorted高级排序技巧
Nov 21 Python
Python查找函数f(x)=0根的解决方法
May 07 Python
Python的组合模式与责任链模式编程示例
Feb 02 Python
浅谈python中的面向对象和类的基本语法
Jun 13 Python
python简单实现获取当前时间
Aug 27 Python
python DataFrame 取差集实例
Jan 30 Python
python3正则提取字符串里的中文实例
Jan 31 Python
python计算无向图节点度的实例代码
Nov 22 Python
Python生成器常见问题及解决方案
Mar 21 Python
python 模拟登陆github的示例
Dec 04 Python
python中翻译功能translate模块实现方法
Dec 17 Python
如何用Python提取10000份log中的产品信息
Jan 14 Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 #Python
pytorch 加载(.pth)格式的模型实例
Aug 20 #Python
python multiprocessing模块用法及原理介绍
Aug 20 #Python
python 并发编程 阻塞IO模型原理解析
Aug 20 #Python
PyTorch中常用的激活函数的方法示例
Aug 20 #Python
You might like
PHP批量删除、清除UTF-8文件BOM头的代码实例
2014/04/14 PHP
PHP+Apache+Mysql环境搭建教程
2016/08/01 PHP
php版交通银行网银支付接口开发入门教程
2016/09/26 PHP
PHP中的自动加载操作实现方法详解
2019/08/06 PHP
解决PhpStorm64不能启动的问题
2020/06/20 PHP
Wordpress ThickBox 点击图片显示下一张图的修改方法
2010/12/11 Javascript
如何制作浮动广告 JavaScript制作浮动广告代码
2012/12/30 Javascript
向当前style sheet中插入一个新的style实现方法
2013/04/01 Javascript
EXTjs4.0的store的findRecord的BUG演示代码
2013/06/08 Javascript
replace()方法查找字符使用示例
2013/10/28 Javascript
js换图片效果可进行定时操作
2014/06/09 Javascript
js中实现多态采用和继承类似的方法
2014/08/22 Javascript
nodejs修复ipa处理过的png图片
2016/02/17 NodeJs
jquery实现简单的banner轮播效果【实例】
2016/03/30 Javascript
微信小程序 教程之小程序配置
2016/10/17 Javascript
jQuery向webApi提交post json数据
2017/01/16 Javascript
JavaScript实现简单的文本逐字打印效果示例
2018/04/12 Javascript
Javascript读写cookie的实例源码
2019/03/16 Javascript
浅谈vue加载优化策略
2019/03/19 Javascript
如何用原生js写一个弹窗消息提醒插件
2019/05/24 Javascript
Vue实现拖放排序功能的实例代码
2019/07/08 Javascript
使用Webpack 搭建 Vue3 开发环境过程详解
2020/07/28 Javascript
vue实现日历表格(element-ui)
2020/09/24 Javascript
[46:12]完美世界DOTA2联赛循环赛 DM vs Matador BO2第一场 11.04
2020/11/04 DOTA
Python中使用PIL库实现图片高斯模糊实例
2015/02/08 Python
Python3.5面向对象与继承图文实例详解
2019/04/24 Python
用pytorch的nn.Module构造简单全链接层实例
2020/01/14 Python
python爬虫实例之获取动漫截图
2020/05/31 Python
奥林匹亚体育:Olympia Sports
2020/12/30 全球购物
解释i节点在文件系统中的作用
2013/11/26 面试题
Ruby中的保护方法和私有方法与一般面向对象程序设计语言的一样吗
2013/05/01 面试题
奶茶店创业计划书范文
2014/01/17 职场文书
工程售后服务承诺书
2014/05/21 职场文书
政工师工作总结2015
2015/05/26 职场文书
提升Nginx性能的一些建议
2021/03/31 Servers
Redis数据同步之redis shake的实现方法
2022/04/21 Redis