浅谈对pytroch中torch.autograd.backward的思考


Posted in Python onDecember 27, 2019

反向传递法则是深度学习中最为重要的一部分,torch中的backward可以对计算图中的梯度进行计算和累积

这里通过一段程序来演示基本的backward操作以及需要注意的地方

>>> import torch
>>> from torch.autograd import Variable

>>> x = Variable(torch.ones(2,2), requires_grad=True)
>>> y = x + 2
>>> y.grad_fn
Out[6]: <torch.autograd.function.AddConstantBackward at 0x229e7068138>
>>> y.grad

>>> z = y*y*3
>>> z.grad_fn
Out[9]: <torch.autograd.function.MulConstantBackward at 0x229e86cc5e8>
>>> z
Out[10]: 
Variable containing:
 27 27
 27 27
[torch.FloatTensor of size 2x2]
>>> out = z.mean()
>>> out.grad_fn
Out[12]: <torch.autograd.function.MeanBackward at 0x229e86cc408>
>>> out.backward()   # 这里因为out为scalar标量,所以参数不需要填写
>>> x.grad
Out[19]: 
Variable containing:
 4.5000 4.5000
 4.5000 4.5000
[torch.FloatTensor of size 2x2]
>>> out  # out为标量
Out[20]: 
Variable containing:
 27
[torch.FloatTensor of size 1]

>>> x = Variable(torch.Tensor([2,2,2]), requires_grad=True)
>>> y = x*2
>>> y
Out[52]: 
Variable containing:
 4
 4
 4
[torch.FloatTensor of size 3]
>>> y.backward() # 因为y输出为非标量,求向量间元素的梯度需要对所求的元素进行标注,用相同长度的序列进行标注
Traceback (most recent call last):
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\IPython\core\interactiveshell.py", line 2862, in run_code
  exec(code_obj, self.user_global_ns, self.user_ns)
 File "<ipython-input-53-95acac9c3254>", line 1, in <module>
  y.backward()
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\variable.py", line 156, in backward
  torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 86, in backward
  grad_variables, create_graph = _make_grads(variables, grad_variables, create_graph)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 34, in _make_grads
  raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        #注意这里的0.1,1.10为梯度求值比例
Out[55]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        # 梯度累积
Out[57]: 
Variable containing:
 0.4000
 4.0000
 40.0000
[torch.FloatTensor of size 3]

>>> x.grad.data.zero_() # 梯度累积进行清零
Out[60]: 
 0
 0
 0
[torch.FloatTensor of size 3]
>>> x.grad       # 累积为空
Out[61]: 
Variable containing:
 0
 0
 0
[torch.FloatTensor of size 3]
>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad
Out[63]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅谈pandas中shift和diff函数关系
Apr 08 Python
python使用socket创建tcp服务器和客户端
Apr 12 Python
python实现屏保计时器的示例代码
Aug 08 Python
使用python实现快速搭建简易的FTP服务器
Sep 12 Python
Python面向对象之类的内置attr属性示例
Dec 14 Python
Django组件cookie与session的具体使用
Jun 05 Python
基于Python函数和变量名解析
Jul 19 Python
python实现将range()函数生成的数字存储在一个列表中
Apr 02 Python
Python基于staticmethod装饰器标示静态方法
Oct 17 Python
Python中的面向接口编程示例详解
Jan 17 Python
python爬取抖音视频的实例分析
Jan 19 Python
python编写函数注意事项总结
Mar 29 Python
python实现异常信息堆栈输出到日志文件
Dec 26 #Python
Python的对象传递与Copy函数使用详解
Dec 26 #Python
Python pandas库中的isnull()详解
Dec 26 #Python
python dataframe NaN处理方式
Dec 26 #Python
python实现大战外星人小游戏实例代码
Dec 26 #Python
Python数据存储之 h5py详解
Dec 26 #Python
Python 使用 prettytable 库打印表格美化输出功能
Dec 26 #Python
You might like
PHP 应用程序的安全 -- 不能违反的四条安全规则
2006/11/26 PHP
php基于base64解码图片与加密图片还原实例
2014/11/03 PHP
CodeIgniter分页类pagination使用方法示例
2016/03/28 PHP
php版微信公众平台回复中文出现乱码问题的解决方法
2016/09/22 PHP
php爬取天猫和淘宝商品数据
2018/02/23 PHP
php实现websocket实时消息推送
2018/03/30 PHP
PHP基于DateTime类解决Unix时间戳与日期互转问题【针对1970年前及2038年后时间戳】
2018/06/13 PHP
JavaScript 权威指南(第四版) 读书笔记
2009/08/11 Javascript
JavaScript面向对象知识串结(读JavaScript高级程序设计(第三版))
2012/07/17 Javascript
简述JavaScript对传统文档对象模型的支持
2015/06/16 Javascript
详解JavaScript语言的基本语法要求
2015/11/20 Javascript
JS平滑无缝滚动效果的实现代码
2016/05/06 Javascript
javascript宿主对象之window.navigator详解
2016/09/07 Javascript
移动端web滚动分页的实现方法
2017/05/05 Javascript
详解基于Node.js的微信JS-SDK后端接口实现代码
2017/07/15 Javascript
浅谈Vue.js 1.x 和 2.x 实例的生命周期
2017/07/25 Javascript
实现单层json按照key字母顺序排序的示例
2017/12/06 Javascript
vue-i18n结合Element-ui的配置方法
2019/05/20 Javascript
jQuery实现高度灵活的表单验证功能示例【无UI】
2020/04/30 jQuery
python处理文本文件并生成指定格式的文件
2014/07/31 Python
Python中实现常量(Const)功能
2015/01/28 Python
Python访问纯真IP数据库脚本分享
2015/06/29 Python
python如何拆分含有多种分隔符的字符串
2018/03/20 Python
基于Python的图像阈值化分割(迭代法)
2020/11/20 Python
Python脚本调试工具安装过程
2021/01/11 Python
马来西亚在线时尚女装商店:KEI MAG
2017/09/28 全球购物
Koral官方网站:女性时尚运动服
2019/04/10 全球购物
护士毕业生自荐信
2014/02/07 职场文书
2014年清明节寄语
2014/04/03 职场文书
幼儿园标语大全
2014/06/19 职场文书
计生工作先进事迹
2014/08/15 职场文书
教师三严三实对照检查材料
2014/09/25 职场文书
2014年仓库管理员工作总结
2014/11/18 职场文书
2015年度村委会工作总结
2015/04/29 职场文书
教师节联欢会主持词
2015/07/04 职场文书
Python编写冷笑话生成器
2022/04/20 Python