浅谈对pytroch中torch.autograd.backward的思考


Posted in Python onDecember 27, 2019

反向传递法则是深度学习中最为重要的一部分,torch中的backward可以对计算图中的梯度进行计算和累积

这里通过一段程序来演示基本的backward操作以及需要注意的地方

>>> import torch
>>> from torch.autograd import Variable

>>> x = Variable(torch.ones(2,2), requires_grad=True)
>>> y = x + 2
>>> y.grad_fn
Out[6]: <torch.autograd.function.AddConstantBackward at 0x229e7068138>
>>> y.grad

>>> z = y*y*3
>>> z.grad_fn
Out[9]: <torch.autograd.function.MulConstantBackward at 0x229e86cc5e8>
>>> z
Out[10]: 
Variable containing:
 27 27
 27 27
[torch.FloatTensor of size 2x2]
>>> out = z.mean()
>>> out.grad_fn
Out[12]: <torch.autograd.function.MeanBackward at 0x229e86cc408>
>>> out.backward()   # 这里因为out为scalar标量,所以参数不需要填写
>>> x.grad
Out[19]: 
Variable containing:
 4.5000 4.5000
 4.5000 4.5000
[torch.FloatTensor of size 2x2]
>>> out  # out为标量
Out[20]: 
Variable containing:
 27
[torch.FloatTensor of size 1]

>>> x = Variable(torch.Tensor([2,2,2]), requires_grad=True)
>>> y = x*2
>>> y
Out[52]: 
Variable containing:
 4
 4
 4
[torch.FloatTensor of size 3]
>>> y.backward() # 因为y输出为非标量,求向量间元素的梯度需要对所求的元素进行标注,用相同长度的序列进行标注
Traceback (most recent call last):
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\IPython\core\interactiveshell.py", line 2862, in run_code
  exec(code_obj, self.user_global_ns, self.user_ns)
 File "<ipython-input-53-95acac9c3254>", line 1, in <module>
  y.backward()
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\variable.py", line 156, in backward
  torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 86, in backward
  grad_variables, create_graph = _make_grads(variables, grad_variables, create_graph)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 34, in _make_grads
  raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        #注意这里的0.1,1.10为梯度求值比例
Out[55]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        # 梯度累积
Out[57]: 
Variable containing:
 0.4000
 4.0000
 40.0000
[torch.FloatTensor of size 3]

>>> x.grad.data.zero_() # 梯度累积进行清零
Out[60]: 
 0
 0
 0
[torch.FloatTensor of size 3]
>>> x.grad       # 累积为空
Out[61]: 
Variable containing:
 0
 0
 0
[torch.FloatTensor of size 3]
>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad
Out[63]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 网页解析HTMLParse的实例详解
Aug 10 Python
Python实现输出程序执行进度百分比的方法
Sep 16 Python
python虚拟环境virtualenv的使用教程
Oct 20 Python
Python实现合并同一个文件夹下所有txt文件的方法示例
Apr 26 Python
Python中的 enum 模块源码详析
Jan 09 Python
Python实现的插入排序,冒泡排序,快速排序,选择排序算法示例
May 04 Python
Pyqt5实现英文学习词典
Jun 24 Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 Python
pytorch构建多模型实例
Jan 15 Python
Python flask路由间传递变量实例详解
Jun 03 Python
Python3安装模块报错Microsoft Visual C++ 14.0 is required的解决方法
Jul 28 Python
Python WebSocket长连接心跳与短连接的示例
Nov 24 Python
python实现异常信息堆栈输出到日志文件
Dec 26 #Python
Python的对象传递与Copy函数使用详解
Dec 26 #Python
Python pandas库中的isnull()详解
Dec 26 #Python
python dataframe NaN处理方式
Dec 26 #Python
python实现大战外星人小游戏实例代码
Dec 26 #Python
Python数据存储之 h5py详解
Dec 26 #Python
Python 使用 prettytable 库打印表格美化输出功能
Dec 26 #Python
You might like
DOMXML函数笔记
2006/10/09 PHP
用php写的serv-u的web申请账号的程序
2006/10/09 PHP
PHP下操作Linux消息队列完成进程间通信的方法
2010/07/24 PHP
教你如何使用php session
2013/10/28 PHP
PHP防止表单重复提交的几种常用方法汇总
2014/08/19 PHP
php获取远程图片并下载保存到本地的方法分析
2016/10/08 PHP
js调试系列 源码定位与调试[基础篇]
2014/06/18 Javascript
JS获取时间的方法
2015/01/21 Javascript
深入分析JSON编码格式提交表单数据
2015/06/25 Javascript
js判断登陆用户名及密码是否为空的简单实例
2016/05/16 Javascript
js实现楼层效果的简单实例
2016/07/15 Javascript
极力推荐10个短小实用的JavaScript代码段
2016/08/03 Javascript
AngularJS入门教程之表单校验用法示例
2016/11/02 Javascript
webpack2.0配置postcss-loader的方法
2017/08/17 Javascript
Angular浏览器插件Batarang介绍及使用
2018/02/07 Javascript
解决vue中使用swiper插件问题及swiper在vue中的用法
2018/04/04 Javascript
微信小程序中遇到的iOS兼容性问题小结
2018/11/14 Javascript
[58:11]守擂赛第二周擂主赛 DeMonsTer vs Leopard
2020/04/28 DOTA
python开发之基于thread线程搜索本地文件的方法
2015/11/11 Python
python爬虫超时的处理的实例
2018/12/19 Python
Python中生成一个指定长度的随机字符串实现示例
2019/11/06 Python
Python.append()与Python.expand()用法详解
2019/12/18 Python
python plt可视化——打印特殊符号和制作图例代码
2020/04/17 Python
Python局部变量与全局变量区别原理解析
2020/07/14 Python
CSS3不透明度实例讲解
2016/04/26 HTML / CSS
介绍一下gcc特性
2015/10/31 面试题
软件工程师面试题
2012/06/25 面试题
医学生个人求职信范文
2013/09/24 职场文书
学校岗位设置方案
2014/01/16 职场文书
委托公证书样本
2015/01/23 职场文书
委托开发合同书(标准版)
2019/08/07 职场文书
六年级作文之预言作文
2019/10/25 职场文书
Python爬虫之爬取某文库文档数据
2021/04/21 Python
Python道路车道线检测的实现
2021/06/27 Python
一文了解MySQL二级索引的查询过程
2022/02/24 MySQL
css常用字体属性与背景属性介绍
2022/02/28 HTML / CSS