关于pytorch中网络loss传播和参数更新的理解


Posted in Python onAugust 20, 2019

相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。

TensorFlow: 228--->266

Keras: 42--->56

Pytorch: 87--->252

在使用pytorch中,自己有一些思考,如下:

1. loss计算和反向传播

import torch.nn as nn
 
criterion = nn.MSELoss().cuda()
 
output = model(input)
 
loss = criterion(output, target)
loss.backward()

通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;

最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。

需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。

2. 网络参数更新

在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。

优化算法中,简单的有一阶优化算法:

关于pytorch中网络loss传播和参数更新的理解

其中关于pytorch中网络loss传播和参数更新的理解 就是通常说的学习率,关于pytorch中网络loss传播和参数更新的理解 是函数的梯度;

自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。

在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

注意:

1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。

2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:

optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01) 3) 在优化前需要先将梯度归零,即optimizer.zeros()。

3. loss计算和参数更新

import torch.nn as nn
import torch
 
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
output = model(input)
 
loss = criterion(output, target)
 
​optimizer.zero_grad() # 将所有参数的梯度都置零
loss.backward()    # 误差反向传播计算参数梯度
optimizer.step()    # 通过梯度做一步参数更新

以上这篇关于pytorch中网络loss传播和参数更新的理解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现简单的文件传输与MySQL备份的脚本分享
Jan 03 Python
简介Python的collections模块中defaultdict类型的用法
Jul 07 Python
Python基于分水岭算法解决走迷宫游戏示例
Sep 26 Python
基于numpy.random.randn()与rand()的区别详解
Apr 17 Python
用python标准库difflib比较两份文件的异同详解
Nov 16 Python
解决python3运行selenium下HTMLTestRunner报错的问题
Dec 27 Python
Python pip替换为阿里源的方法步骤
Jul 02 Python
Python项目 基于Scapy实现SYN泛洪攻击的方法
Jul 23 Python
详解Python中的format格式化函数的使用方法
Nov 20 Python
python matplotlib中的subplot函数使用详解
Jan 19 Python
django xadmin action兼容自定义model权限教程
Mar 30 Python
python中@contextmanager实例用法
Feb 07 Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 #Python
pytorch 加载(.pth)格式的模型实例
Aug 20 #Python
python multiprocessing模块用法及原理介绍
Aug 20 #Python
python 并发编程 阻塞IO模型原理解析
Aug 20 #Python
PyTorch中常用的激活函数的方法示例
Aug 20 #Python
You might like
在数据量大(超过10万)的情况下
2007/01/15 PHP
php excel类 phpExcel使用方法介绍
2010/08/21 PHP
PHP项目多语言配置平台实现过程解析
2020/05/18 PHP
Google Suggest ;-) 基于js的动态下拉菜单
2006/10/11 Javascript
jQuery学习笔记之jQuery的事件
2010/12/22 Javascript
常用js字符串判断方法整理
2013/10/18 Javascript
javascript替换已有元素replaceChild()使用介绍
2014/04/03 Javascript
JS来动态的修改url实现对url的增删查改
2014/09/05 Javascript
jQuery实现仿百度首页滑动伸缩展开的添加服务效果代码
2015/09/09 Javascript
基于jQuery实现图片推拉门动画效果的两种方法
2017/08/26 jQuery
利用百度echarts实现图表功能简单入门示例【附源码下载】
2019/06/10 Javascript
Nodejs技巧之Exceljs表格操作用法示例
2019/11/06 NodeJs
详解Node.js使用token进行认证的简单示例
2020/05/25 Javascript
Python中AND、OR的一个使用小技巧
2015/02/18 Python
Python中的hypot()方法使用简介
2015/05/18 Python
Python判断某个用户对某个文件的权限
2016/10/13 Python
python3+PyQt5实现自定义窗口部件Counters
2018/04/20 Python
Python读取数据集并消除数据中的空行方法
2018/07/12 Python
Python人脸识别第三方库face_recognition接口说明文档
2019/05/03 Python
Django缓存系统实现过程解析
2019/08/02 Python
使用TensorFlow直接获取处理MNIST数据方式
2020/02/10 Python
python爬虫中的url下载器用法详解
2020/11/30 Python
浅谈Html5多线程开发之WebWorkers
2018/05/02 HTML / CSS
英国家居用品和家居装饰品购物网站:Cox & Cox
2019/08/25 全球购物
4s客服专员岗位职责
2013/12/01 职场文书
租房合同协议书
2014/04/09 职场文书
护林防火标语
2014/06/27 职场文书
幼儿教师师德师风自我剖析材料
2014/09/29 职场文书
学生检讨书如何写
2014/10/30 职场文书
2015年爱国卫生工作总结
2015/04/22 职场文书
公司董事任命书
2015/09/21 职场文书
学校中层领导培训心得体会
2016/01/11 职场文书
高考升学宴主持词
2019/06/21 职场文书
7个你应该知道的JS原生错误类型
2021/04/29 Javascript
国产动画《万圣街》日语配音版制作决定!
2022/03/20 国漫
Mysql调整优化之四种分区方式以及组合分区
2022/04/13 MySQL