关于pytorch中网络loss传播和参数更新的理解


Posted in Python onAugust 20, 2019

相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。

TensorFlow: 228--->266

Keras: 42--->56

Pytorch: 87--->252

在使用pytorch中,自己有一些思考,如下:

1. loss计算和反向传播

import torch.nn as nn
 
criterion = nn.MSELoss().cuda()
 
output = model(input)
 
loss = criterion(output, target)
loss.backward()

通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;

最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。

需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。

2. 网络参数更新

在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。

优化算法中,简单的有一阶优化算法:

关于pytorch中网络loss传播和参数更新的理解

其中关于pytorch中网络loss传播和参数更新的理解 就是通常说的学习率,关于pytorch中网络loss传播和参数更新的理解 是函数的梯度;

自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。

在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

注意:

1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。

2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:

optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01) 3) 在优化前需要先将梯度归零,即optimizer.zeros()。

3. loss计算和参数更新

import torch.nn as nn
import torch
 
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
output = model(input)
 
loss = criterion(output, target)
 
​optimizer.zero_grad() # 将所有参数的梯度都置零
loss.backward()    # 误差反向传播计算参数梯度
optimizer.step()    # 通过梯度做一步参数更新

以上这篇关于pytorch中网络loss传播和参数更新的理解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python计数排序和基数排序算法实例
Apr 25 Python
Python编程中对文件和存储器的读写示例
Jan 25 Python
python实现装饰器、描述符
Feb 28 Python
Python 中的range(),以及列表切片方法
Jul 02 Python
在python中使用with打开多个文件的方法
Jan 07 Python
Django文件存储 默认存储系统解析
Aug 02 Python
Python 操作 ElasticSearch的完整代码
Aug 04 Python
python hashlib加密实现代码
Oct 17 Python
Python 输出详细的异常信息(traceback)方式
Apr 08 Python
Python+redis通过限流保护高并发系统
Apr 15 Python
Django中celery的使用项目实例
Jul 07 Python
python解析照片拍摄时间进行图片整理
Jul 23 Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 #Python
pytorch 加载(.pth)格式的模型实例
Aug 20 #Python
python multiprocessing模块用法及原理介绍
Aug 20 #Python
python 并发编程 阻塞IO模型原理解析
Aug 20 #Python
PyTorch中常用的激活函数的方法示例
Aug 20 #Python
You might like
mysq GBKl乱码
2006/11/28 PHP
PHP类的使用 实例代码讲解
2009/12/28 PHP
php-cli简介(不会Shell语言一样用Shell)
2013/06/03 PHP
PHP 关于访问控制的和运算符优先级介绍
2013/07/08 PHP
php缩小png图片不损失透明色的解决方法
2013/12/25 PHP
初识laravel5
2015/03/02 PHP
php实现复制移动文件的方法
2015/07/29 PHP
PHP设计模式之观察者模式实例
2016/02/22 PHP
js实现全屏漂浮广告移入光标停止移动
2013/12/02 Javascript
一段非常简单的js判断浏览器的内核
2014/08/17 Javascript
通过JS来动态的修改url,实现对url的增删查改
2014/09/01 Javascript
JS实现简单的星期格式转换功能示例
2018/07/23 Javascript
vue2使用keep-alive缓存多层列表页的方法
2018/09/21 Javascript
实现Vue的markdown文档可以在线运行的方法示例
2018/12/11 Javascript
JavaScript对象属性操作实例解析
2020/02/04 Javascript
jQuery表单校验插件validator使用方法详解
2020/02/18 jQuery
JavaScript 链表定义与使用方法示例
2020/04/28 Javascript
JS数组push、unshift、pop、shift方法的实现与使用方法示例
2020/04/29 Javascript
Python脚本实现Web漏洞扫描工具
2016/10/25 Python
pandas数据清洗,排序,索引设置,数据选取方法
2018/05/18 Python
python使用rpc框架gRPC的方法
2018/08/24 Python
Python图像滤波处理操作示例【基于ImageFilter类】
2019/01/03 Python
Python进程间通信 multiProcessing Queue队列实现详解
2019/09/23 Python
pyinstaller打包程序exe踩过的坑
2019/11/19 Python
CSS3实现缺角矩形,折角矩形以及缺角边框
2019/12/20 HTML / CSS
Perfumetrader荷兰:香水、化妆品和护肤品在线商店
2017/09/15 全球购物
EQVVS官网:设计师男装和女装
2018/10/24 全球购物
有原因的手表:Flex Watches
2019/03/23 全球购物
什么是典型的软件三层结构?软件设计为什么要分层?软件分层有什么好处?
2012/03/14 面试题
远程调用的原理
2014/07/05 面试题
会计与审计毕业生自荐信范文
2013/12/30 职场文书
水果连锁超市创业计划书
2014/01/24 职场文书
军训自我鉴定怎么写
2014/02/13 职场文书
安全责任书范本
2014/04/15 职场文书
房产授权委托书范本
2014/09/22 职场文书
什么是检讨书?检讨书的格式及范文
2019/11/05 职场文书