关于pytorch中网络loss传播和参数更新的理解


Posted in Python onAugust 20, 2019

相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。

TensorFlow: 228--->266

Keras: 42--->56

Pytorch: 87--->252

在使用pytorch中,自己有一些思考,如下:

1. loss计算和反向传播

import torch.nn as nn
 
criterion = nn.MSELoss().cuda()
 
output = model(input)
 
loss = criterion(output, target)
loss.backward()

通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;

最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。

需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。

2. 网络参数更新

在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。

优化算法中,简单的有一阶优化算法:

关于pytorch中网络loss传播和参数更新的理解

其中关于pytorch中网络loss传播和参数更新的理解 就是通常说的学习率,关于pytorch中网络loss传播和参数更新的理解 是函数的梯度;

自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。

在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

注意:

1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。

2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:

optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01) 3) 在优化前需要先将梯度归零,即optimizer.zeros()。

3. loss计算和参数更新

import torch.nn as nn
import torch
 
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
output = model(input)
 
loss = criterion(output, target)
 
​optimizer.zero_grad() # 将所有参数的梯度都置零
loss.backward()    # 误差反向传播计算参数梯度
optimizer.step()    # 通过梯度做一步参数更新

以上这篇关于pytorch中网络loss传播和参数更新的理解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现从百度API获取天气的方法
Mar 11 Python
Python实现的密码强度检测器示例
Aug 23 Python
python中验证码连通域分割的方法详解
Jun 04 Python
python 用opencv调用训练好的模型进行识别的方法
Dec 07 Python
解决python2 绘图title,xlabel,ylabel出现中文乱码的问题
Jan 29 Python
Python中判断子串存在的性能比较及分析总结
Jun 23 Python
pandas取出重复数据的方法
Jul 04 Python
在VS2017中用C#调用python脚本的实现
Jul 31 Python
python多线程实现同时执行两个while循环的操作
May 02 Python
Softmax函数原理及Python实现过程解析
May 22 Python
Python实现Appium端口检测与释放的实现
Dec 31 Python
python将图片转为矢量图的方法步骤
Mar 30 Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 #Python
pytorch 加载(.pth)格式的模型实例
Aug 20 #Python
python multiprocessing模块用法及原理介绍
Aug 20 #Python
python 并发编程 阻塞IO模型原理解析
Aug 20 #Python
PyTorch中常用的激活函数的方法示例
Aug 20 #Python
You might like
《OVERLORD》第四季,终于等到你!
2020/03/02 日漫
php比较两个绝对时间的大小
2014/01/31 PHP
Yii2中添加全局函数的方法分析
2017/05/04 PHP
浅谈PHP封装CURL
2019/03/06 PHP
PHP 命名空间和自动加载原理与用法实例分析
2020/04/29 PHP
Aster vs Newbee BO3 第三场2.18
2021/03/10 DOTA
腾讯UED 漂亮的提示信息效果代码
2011/09/12 Javascript
JQuery模板插件 jquery.tmpl 动态ajax扩展
2011/11/10 Javascript
javascript事件模型实例分析
2015/01/30 Javascript
js日期范围初始化得到前一个月日期的方法
2015/05/05 Javascript
javascript实现可全选、反选及删除表格的方法
2015/05/15 Javascript
jQuery实现平滑滚动页面到指定锚点链接的方法
2015/07/15 Javascript
JS留言功能的简单实现案例(推荐)
2016/06/23 Javascript
JavaScript基础知识点归纳(推荐)
2016/07/09 Javascript
纯js实现动态时间显示
2020/09/07 Javascript
Node.js 8 中的重要新特性
2017/06/28 Javascript
解决vue-router进行build无法正常显示路由页面的问题
2018/03/06 Javascript
vue实现codemirror代码编辑器中的SQL代码格式化功能
2019/08/27 Javascript
云服务器部署Node.js项目的方法步骤(小白系列)
2020/03/23 Javascript
[00:37]食人魔魔法师轮盘吉兆顺应全新至宝将拥有额外款式
2019/12/19 DOTA
Python读取键盘输入的2种方法
2015/06/16 Python
Python实现的knn算法示例
2018/06/14 Python
matplotlib给子图添加图例的方法
2018/08/03 Python
配置 Pycharm 默认 Test runner 的图文教程
2018/11/30 Python
如何基于Python实现数字类型转换
2020/02/07 Python
pyqt5中动画的使用详解
2020/04/01 Python
pandas实现导出数据的四种方式
2020/12/13 Python
日本航空官方网站:JAL
2019/06/19 全球购物
瑞典网上购买现代和复古家具:Reforma
2019/10/21 全球购物
Java程序员面试题
2013/07/15 面试题
师范生自荐信范文
2013/10/06 职场文书
会计应聘求职信范文
2013/12/17 职场文书
酒店保安领班职务说明书
2014/03/04 职场文书
2015年乡镇统计工作总结
2015/04/22 职场文书
安全生产奖惩制度
2015/08/06 职场文书
Maven学习----Maven安装与环境变量配置教程
2021/06/29 Java/Android