浅谈对pytroch中torch.autograd.backward的思考


Posted in Python onDecember 27, 2019

反向传递法则是深度学习中最为重要的一部分,torch中的backward可以对计算图中的梯度进行计算和累积

这里通过一段程序来演示基本的backward操作以及需要注意的地方

>>> import torch
>>> from torch.autograd import Variable

>>> x = Variable(torch.ones(2,2), requires_grad=True)
>>> y = x + 2
>>> y.grad_fn
Out[6]: <torch.autograd.function.AddConstantBackward at 0x229e7068138>
>>> y.grad

>>> z = y*y*3
>>> z.grad_fn
Out[9]: <torch.autograd.function.MulConstantBackward at 0x229e86cc5e8>
>>> z
Out[10]: 
Variable containing:
 27 27
 27 27
[torch.FloatTensor of size 2x2]
>>> out = z.mean()
>>> out.grad_fn
Out[12]: <torch.autograd.function.MeanBackward at 0x229e86cc408>
>>> out.backward()   # 这里因为out为scalar标量,所以参数不需要填写
>>> x.grad
Out[19]: 
Variable containing:
 4.5000 4.5000
 4.5000 4.5000
[torch.FloatTensor of size 2x2]
>>> out  # out为标量
Out[20]: 
Variable containing:
 27
[torch.FloatTensor of size 1]

>>> x = Variable(torch.Tensor([2,2,2]), requires_grad=True)
>>> y = x*2
>>> y
Out[52]: 
Variable containing:
 4
 4
 4
[torch.FloatTensor of size 3]
>>> y.backward() # 因为y输出为非标量,求向量间元素的梯度需要对所求的元素进行标注,用相同长度的序列进行标注
Traceback (most recent call last):
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\IPython\core\interactiveshell.py", line 2862, in run_code
  exec(code_obj, self.user_global_ns, self.user_ns)
 File "<ipython-input-53-95acac9c3254>", line 1, in <module>
  y.backward()
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\variable.py", line 156, in backward
  torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 86, in backward
  grad_variables, create_graph = _make_grads(variables, grad_variables, create_graph)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 34, in _make_grads
  raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        #注意这里的0.1,1.10为梯度求值比例
Out[55]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        # 梯度累积
Out[57]: 
Variable containing:
 0.4000
 4.0000
 40.0000
[torch.FloatTensor of size 3]

>>> x.grad.data.zero_() # 梯度累积进行清零
Out[60]: 
 0
 0
 0
[torch.FloatTensor of size 3]
>>> x.grad       # 累积为空
Out[61]: 
Variable containing:
 0
 0
 0
[torch.FloatTensor of size 3]
>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad
Out[63]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基于BeautifulSoup实现抓取网页指定内容的方法
Jul 09 Python
Python环境搭建之OpenCV的步骤方法
Oct 20 Python
python3 flask实现文件上传功能
Mar 20 Python
kafka-python批量发送数据的实例
Dec 27 Python
对python for 文件指定行读写操作详解
Dec 29 Python
python 内置模块详解
Jan 01 Python
Python常见的pandas用法demo示例
Mar 16 Python
pymongo中聚合查询的使用方法
Mar 22 Python
Python使用lambda表达式对字典排序操作示例
Jul 25 Python
python实现超市商品销售管理系统
Nov 22 Python
在python中求分布函数相关的包实例
Apr 15 Python
Python如何实现定时器功能
May 28 Python
python实现异常信息堆栈输出到日志文件
Dec 26 #Python
Python的对象传递与Copy函数使用详解
Dec 26 #Python
Python pandas库中的isnull()详解
Dec 26 #Python
python dataframe NaN处理方式
Dec 26 #Python
python实现大战外星人小游戏实例代码
Dec 26 #Python
Python数据存储之 h5py详解
Dec 26 #Python
Python 使用 prettytable 库打印表格美化输出功能
Dec 26 #Python
You might like
PHP重定向的3种方式
2013/03/07 PHP
PHP中func_get_args(),func_get_arg(),func_num_args()的区别
2013/09/30 PHP
PHP json_decode函数详细解析
2014/02/17 PHP
PHP实现的自定义图像居中裁剪函数示例【测试可用】
2017/08/11 PHP
详解Laravel5.6 Passport实现Api接口认证
2018/07/27 PHP
新鲜出炉的js tips提示效果
2011/04/03 Javascript
JavaScript学习笔记(一) js基本语法
2011/10/25 Javascript
用js实现in_array的方法
2013/11/05 Javascript
jquery队列queue与原生模仿其实现方法分享
2014/03/25 Javascript
IE6中链接A的href为javascript协议时不在当前页面跳转
2014/06/05 Javascript
jQuery中removeData()方法用法实例
2014/12/27 Javascript
JavaScript中的console.group()函数详细介绍
2014/12/29 Javascript
Boostrap模态窗口的学习小结
2016/03/28 Javascript
AngularJS 过滤器的简单实例
2016/07/27 Javascript
jQuery包裹节点用法完整示例
2016/09/13 Javascript
解析Json字符串的三种方法日常常用
2018/05/02 Javascript
vue devtools的安装与使用教程
2018/08/08 Javascript
vue绑定事件后获取绑定事件中的this方法
2018/09/15 Javascript
开发一个Parcel-vue脚手架工具(详细步骤)
2018/09/22 Javascript
Element InfiniteScroll无限滚动的具体使用方法
2020/07/27 Javascript
element-ui中dialog弹窗关闭按钮失效的解决
2020/09/22 Javascript
Python获取电脑硬件信息及状态的实现方法
2014/08/29 Python
python中print的不换行即时输出的快速解决方法
2016/07/20 Python
python遍历 truple list dictionary的几种方法总结
2016/09/11 Python
python selenium 执行完毕关闭chromedriver进程示例
2019/11/15 Python
详解Python中@staticmethod和@classmethod区别及使用示例代码
2020/12/14 Python
html5 拖拽及用 js 实现拖拽功能的示例代码
2020/10/23 HTML / CSS
诗狄娜化妆品官方网站:Stila Cosmetics
2016/12/21 全球购物
法国发饰品牌:Alexandre De Paris
2018/12/04 全球购物
银行职员个人的工作自我评价
2014/02/15 职场文书
餐饮企业总经理岗位职责范文
2014/02/18 职场文书
抗震救灾标语
2014/06/26 职场文书
红白喜事主持词
2015/07/06 职场文书
SQL Server——索引+基于单表的数据插入与简单查询【1】
2021/04/05 SQL Server
mysql 8.0.24版本安装配置方法图文教程
2021/05/12 MySQL
python数据可视化JupyterLab实用扩展程序Mito
2021/11/20 Python