关于pytorch中网络loss传播和参数更新的理解


Posted in Python onAugust 20, 2019

相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。

TensorFlow: 228--->266

Keras: 42--->56

Pytorch: 87--->252

在使用pytorch中,自己有一些思考,如下:

1. loss计算和反向传播

import torch.nn as nn
 
criterion = nn.MSELoss().cuda()
 
output = model(input)
 
loss = criterion(output, target)
loss.backward()

通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;

最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。

需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。

2. 网络参数更新

在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。

优化算法中,简单的有一阶优化算法:

关于pytorch中网络loss传播和参数更新的理解

其中关于pytorch中网络loss传播和参数更新的理解 就是通常说的学习率,关于pytorch中网络loss传播和参数更新的理解 是函数的梯度;

自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。

在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

注意:

1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。

2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:

optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01) 3) 在优化前需要先将梯度归零,即optimizer.zeros()。

3. loss计算和参数更新

import torch.nn as nn
import torch
 
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
output = model(input)
 
loss = criterion(output, target)
 
​optimizer.zero_grad() # 将所有参数的梯度都置零
loss.backward()    # 误差反向传播计算参数梯度
optimizer.step()    # 通过梯度做一步参数更新

以上这篇关于pytorch中网络loss传播和参数更新的理解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python发送邮件的实例代码(支持html、图片、附件)
Mar 04 Python
基于wxpython实现的windows GUI程序实例
May 30 Python
Python简单获取自身外网IP的方法
Sep 18 Python
python+pillow绘制矩阵盖尔圆简单实例
Jan 16 Python
解决django后台样式丢失,css资源加载失败的问题
Jun 11 Python
pandas 选取行和列数据的方法详解
Aug 08 Python
Python实现打印实心和空心菱形
Nov 23 Python
python实现超级马里奥
Mar 18 Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 Python
Python通过fnmatch模块实现文件名匹配
Sep 30 Python
Python模拟键盘输入自动登录TGP
Nov 27 Python
Python tkinter之ComboBox(下拉框)的使用简介
Feb 05 Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 #Python
pytorch 加载(.pth)格式的模型实例
Aug 20 #Python
python multiprocessing模块用法及原理介绍
Aug 20 #Python
python 并发编程 阻塞IO模型原理解析
Aug 20 #Python
PyTorch中常用的激活函数的方法示例
Aug 20 #Python
You might like
php 清除网页病毒的方法
2008/12/05 PHP
浅析PHP原理之变量(Variables inside PHP)
2013/08/09 PHP
通过php添加xml文档内容的方法
2015/01/23 PHP
php递归遍历多维数组的方法
2015/04/18 PHP
网页的分页下标生成代码(PHP后端方法)
2016/02/03 PHP
Laravel 实现Eloquent模型分组查询并返回每个分组的数量 groupBy()
2019/10/23 PHP
PHP ob缓存以及ob函数原理实例解析
2020/11/13 PHP
JavaScript游戏之是男人就下100层代码打包
2010/11/08 Javascript
js打开windows上的可执行文件示例
2014/05/27 Javascript
快速学习JavaScript的6个思维技巧
2015/10/13 Javascript
seajs加载jquery时提示$ is not a function该怎么解决
2015/10/23 Javascript
jQuery实现非常实用漂亮的select下拉菜单选择效果
2015/11/06 Javascript
JavaScript操作select元素和option的实例代码
2016/01/29 Javascript
微信小程序 数据访问实例详解
2016/10/08 Javascript
Vue2.x中的父子组件相互通信的实现方法
2017/05/02 Javascript
原生JS实现图片懒加载(lazyload)实例
2017/06/13 Javascript
angular2+node.js express打包部署的实战
2017/07/27 Javascript
基于vue的短信验证码倒计时demo
2017/09/13 Javascript
JS动画定时器知识总结
2018/03/23 Javascript
Javascript实现运算符重载详解
2018/04/07 Javascript
QRCode.js二维码生成并能长按识别
2018/10/16 Javascript
微信小程序实现图片选择并预览功能
2019/07/25 Javascript
JavaScript canvas实现雨滴特效
2021/01/10 Javascript
python实现的文件同步服务器实例
2015/06/02 Python
Django使用HttpResponse返回图片并显示的方法
2018/05/22 Python
python爬虫 urllib模块url编码处理详解
2019/08/20 Python
python 获取字典特定值对应的键的实现
2020/09/29 Python
美津浓美国官网:Mizuno美国
2018/08/07 全球购物
以实惠的价格轻松租车,免费取消:Easyrentcars
2019/07/16 全球购物
建龙钢铁面试总结
2014/04/15 面试题
机电系毕业生求职信
2014/07/11 职场文书
教师思想作风整顿个人剖析材料
2014/10/10 职场文书
刑事起诉书范文
2015/05/19 职场文书
2016年寒假社会实践活动总结
2015/10/10 职场文书
《走遍天下书为侣》教学反思
2016/02/22 职场文书
如何用threejs实现实时多边形折射
2021/05/07 Javascript