浅谈对pytroch中torch.autograd.backward的思考


Posted in Python onDecember 27, 2019

反向传递法则是深度学习中最为重要的一部分,torch中的backward可以对计算图中的梯度进行计算和累积

这里通过一段程序来演示基本的backward操作以及需要注意的地方

>>> import torch
>>> from torch.autograd import Variable

>>> x = Variable(torch.ones(2,2), requires_grad=True)
>>> y = x + 2
>>> y.grad_fn
Out[6]: <torch.autograd.function.AddConstantBackward at 0x229e7068138>
>>> y.grad

>>> z = y*y*3
>>> z.grad_fn
Out[9]: <torch.autograd.function.MulConstantBackward at 0x229e86cc5e8>
>>> z
Out[10]: 
Variable containing:
 27 27
 27 27
[torch.FloatTensor of size 2x2]
>>> out = z.mean()
>>> out.grad_fn
Out[12]: <torch.autograd.function.MeanBackward at 0x229e86cc408>
>>> out.backward()   # 这里因为out为scalar标量,所以参数不需要填写
>>> x.grad
Out[19]: 
Variable containing:
 4.5000 4.5000
 4.5000 4.5000
[torch.FloatTensor of size 2x2]
>>> out  # out为标量
Out[20]: 
Variable containing:
 27
[torch.FloatTensor of size 1]

>>> x = Variable(torch.Tensor([2,2,2]), requires_grad=True)
>>> y = x*2
>>> y
Out[52]: 
Variable containing:
 4
 4
 4
[torch.FloatTensor of size 3]
>>> y.backward() # 因为y输出为非标量,求向量间元素的梯度需要对所求的元素进行标注,用相同长度的序列进行标注
Traceback (most recent call last):
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\IPython\core\interactiveshell.py", line 2862, in run_code
  exec(code_obj, self.user_global_ns, self.user_ns)
 File "<ipython-input-53-95acac9c3254>", line 1, in <module>
  y.backward()
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\variable.py", line 156, in backward
  torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 86, in backward
  grad_variables, create_graph = _make_grads(variables, grad_variables, create_graph)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 34, in _make_grads
  raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        #注意这里的0.1,1.10为梯度求值比例
Out[55]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        # 梯度累积
Out[57]: 
Variable containing:
 0.4000
 4.0000
 40.0000
[torch.FloatTensor of size 3]

>>> x.grad.data.zero_() # 梯度累积进行清零
Out[60]: 
 0
 0
 0
[torch.FloatTensor of size 3]
>>> x.grad       # 累积为空
Out[61]: 
Variable containing:
 0
 0
 0
[torch.FloatTensor of size 3]
>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad
Out[63]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python二维码生成库qrcode安装和使用示例
Dec 16 Python
python+Django+apache的配置方法详解
Jun 01 Python
python在Windows下安装setuptools(easy_install工具)步骤详解
Jul 01 Python
Python hashlib模块用法实例分析
Jun 12 Python
对python 数据处理中的LabelEncoder 和 OneHotEncoder详解
Jul 11 Python
win7 x64系统中安装Scrapy的方法
Nov 18 Python
使用Python-OpenCV向图片添加噪声的实现(高斯噪声、椒盐噪声)
May 28 Python
python Django 创建应用过程图示详解
Jul 29 Python
python库matplotlib绘制坐标图
Oct 18 Python
Python 代码调试技巧示例代码
Aug 11 Python
Python经常使用的一些内置函数
Apr 11 Python
python数字图像处理实现图像的形变与缩放
Jun 28 Python
python实现异常信息堆栈输出到日志文件
Dec 26 #Python
Python的对象传递与Copy函数使用详解
Dec 26 #Python
Python pandas库中的isnull()详解
Dec 26 #Python
python dataframe NaN处理方式
Dec 26 #Python
python实现大战外星人小游戏实例代码
Dec 26 #Python
Python数据存储之 h5py详解
Dec 26 #Python
Python 使用 prettytable 库打印表格美化输出功能
Dec 26 #Python
You might like
mysql中存储过程、函数的一些问题
2007/02/14 PHP
PHP字符转义相关函数小结(php下的转义字符串)
2007/04/12 PHP
JavaScript call apply使用 JavaScript对象的方法绑定到DOM事件后this指向问题
2011/09/28 Javascript
Js 回车换行处理的办法及replace方法应用
2013/01/24 Javascript
jquery弹出关闭遮罩层实例
2013/08/06 Javascript
探讨jQuery的ajax使用场景(c#)
2013/12/03 Javascript
浅谈JavaScript函数参数的可修改性问题
2013/12/05 Javascript
JQuery调用WebServices的方法和4个实例
2014/05/06 Javascript
AngularJS中$interval的用法详解
2016/02/02 Javascript
详解VUE自定义组件中用.sync修饰符与v-model的区别
2018/06/26 Javascript
详解微信小程序scroll-view横向滚动的实践踩坑及隐藏其滚动条的实现
2019/03/14 Javascript
vue-cli webpack配置文件分析
2019/05/20 Javascript
详解vue-cli3多页应用改造
2019/06/04 Javascript
ES6基础之字符串和函数的拓展详解
2019/08/22 Javascript
vue实现倒计时获取验证码效果
2020/04/17 Javascript
简单使用webpack打包文件的实现
2019/10/29 Javascript
selenium 反爬虫之跳过淘宝滑块验证功能的实现代码
2020/08/27 Javascript
[01:09]DOTA2次级职业联赛 - 99战队宣传片
2014/12/01 DOTA
[04:23]DOTA2上海特锦赛小组赛第一日 TOP10精彩集锦
2016/02/27 DOTA
Python通过PIL获取图片主要颜色并和颜色库进行对比的方法
2015/03/19 Python
Python新手在作用域方面经常容易碰到的问题
2015/04/03 Python
一步步解析Python斗牛游戏的概率
2016/02/12 Python
详解Python的Twisted框架中reactor事件管理器的用法
2016/05/25 Python
python pandas dataframe 行列选择,切片操作方法
2018/04/10 Python
TensorFlow 合并/连接数组的方法
2018/07/27 Python
Python3列表内置方法大全及示例代码小结
2019/05/10 Python
Django框架实现分页显示内容的方法详解
2019/05/10 Python
解决Django提交表单报错:CSRF token missing or incorrect的问题
2020/03/13 Python
浅谈Python程序的错误:变量未定义
2020/06/02 Python
anaconda安装pytorch1.7.1和torchvision0.8.2的方法(亲测可用)
2021/02/01 Python
实习单位鉴定评语
2014/04/26 职场文书
2015年乡镇食品安全工作总结
2015/10/22 职场文书
党员干部学法用法心得体会
2016/01/21 职场文书
数学复习课教学反思
2016/02/18 职场文书
CSS3 菱形拼图实现只旋转div 背景图片不旋转功能
2021/03/30 HTML / CSS
vue使用echarts实现折线图
2022/03/21 Vue.js