浅谈对pytroch中torch.autograd.backward的思考


Posted in Python onDecember 27, 2019

反向传递法则是深度学习中最为重要的一部分,torch中的backward可以对计算图中的梯度进行计算和累积

这里通过一段程序来演示基本的backward操作以及需要注意的地方

>>> import torch
>>> from torch.autograd import Variable

>>> x = Variable(torch.ones(2,2), requires_grad=True)
>>> y = x + 2
>>> y.grad_fn
Out[6]: <torch.autograd.function.AddConstantBackward at 0x229e7068138>
>>> y.grad

>>> z = y*y*3
>>> z.grad_fn
Out[9]: <torch.autograd.function.MulConstantBackward at 0x229e86cc5e8>
>>> z
Out[10]: 
Variable containing:
 27 27
 27 27
[torch.FloatTensor of size 2x2]
>>> out = z.mean()
>>> out.grad_fn
Out[12]: <torch.autograd.function.MeanBackward at 0x229e86cc408>
>>> out.backward()   # 这里因为out为scalar标量,所以参数不需要填写
>>> x.grad
Out[19]: 
Variable containing:
 4.5000 4.5000
 4.5000 4.5000
[torch.FloatTensor of size 2x2]
>>> out  # out为标量
Out[20]: 
Variable containing:
 27
[torch.FloatTensor of size 1]

>>> x = Variable(torch.Tensor([2,2,2]), requires_grad=True)
>>> y = x*2
>>> y
Out[52]: 
Variable containing:
 4
 4
 4
[torch.FloatTensor of size 3]
>>> y.backward() # 因为y输出为非标量,求向量间元素的梯度需要对所求的元素进行标注,用相同长度的序列进行标注
Traceback (most recent call last):
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\IPython\core\interactiveshell.py", line 2862, in run_code
  exec(code_obj, self.user_global_ns, self.user_ns)
 File "<ipython-input-53-95acac9c3254>", line 1, in <module>
  y.backward()
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\variable.py", line 156, in backward
  torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 86, in backward
  grad_variables, create_graph = _make_grads(variables, grad_variables, create_graph)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 34, in _make_grads
  raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        #注意这里的0.1,1.10为梯度求值比例
Out[55]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        # 梯度累积
Out[57]: 
Variable containing:
 0.4000
 4.0000
 40.0000
[torch.FloatTensor of size 3]

>>> x.grad.data.zero_() # 梯度累积进行清零
Out[60]: 
 0
 0
 0
[torch.FloatTensor of size 3]
>>> x.grad       # 累积为空
Out[61]: 
Variable containing:
 0
 0
 0
[torch.FloatTensor of size 3]
>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad
Out[63]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python set集合类型操作总结
Nov 07 Python
Python2.x和3.x下maketrans与translate函数使用上的不同
Apr 13 Python
python从入门到精通(DAY 1)
Dec 20 Python
简介Python设计模式中的代理模式与模板方法模式编程
Feb 02 Python
利用PyInstaller将python程序.py转为.exe的方法详解
May 03 Python
Python机器学习之决策树算法
Dec 22 Python
对python中使用requests模块参数编码的不同处理方法
May 18 Python
Pandas 按索引合并数据集的方法
Nov 15 Python
python 将字符串中的数字相加求和的实现
Jul 18 Python
Python使用qrcode二维码库生成二维码方法详解
Feb 17 Python
python Yaml、Json、Dict之间的转化
Oct 19 Python
python套接字socket通信
Apr 01 Python
python实现异常信息堆栈输出到日志文件
Dec 26 #Python
Python的对象传递与Copy函数使用详解
Dec 26 #Python
Python pandas库中的isnull()详解
Dec 26 #Python
python dataframe NaN处理方式
Dec 26 #Python
python实现大战外星人小游戏实例代码
Dec 26 #Python
Python数据存储之 h5py详解
Dec 26 #Python
Python 使用 prettytable 库打印表格美化输出功能
Dec 26 #Python
You might like
codeigniter发送邮件并打印调试信息的方法
2015/03/21 PHP
Thinkphp5.0自动生成模块及目录的方法详解
2017/04/17 PHP
thinkphp框架表单数组实现图片批量上传功能示例
2020/04/04 PHP
javascript 系统文件夹文件操作及参数介绍
2013/01/08 Javascript
JavaScript设置首页和收藏页面的小例子
2013/11/11 Javascript
js获取当前路径的简单示例代码
2014/01/08 Javascript
div失去焦点事件实现思路
2014/04/22 Javascript
Jquery解析json字符串及json数组的方法
2015/05/29 Javascript
js下拉选择框与输入框联动实现添加选中值到输入框的方法
2015/08/17 Javascript
JavaScript学习笔记(三):JavaScript也有入口Main函数
2015/09/12 Javascript
整理Javascript基础入门学习笔记
2015/11/29 Javascript
js判断手机访问或者PC的几个例子(常用于手机跳转)
2015/12/15 Javascript
用js将long型数据转换成date型或datetime型的实例
2017/07/03 Javascript
JS使用setInterval实现的简单计时器功能示例
2018/04/19 Javascript
vue-cli项目代理proxyTable配置exclude的方法
2018/09/20 Javascript
使用Angular Cli如何创建Angular私有库详解
2019/01/30 Javascript
微信小程序学习笔记之函数定义、页面渲染图文详解
2019/03/28 Javascript
python高并发异步服务器核心库forkcore使用方法
2013/11/26 Python
python实现验证码识别功能
2018/06/07 Python
Django管理员账号和密码忘记的完美解决方法
2018/12/06 Python
python批量下载抖音视频
2019/06/17 Python
python flask 如何修改默认端口号的方法步骤
2019/07/12 Python
pip install 使用国内镜像的方法示例
2020/04/03 Python
python按顺序重命名文件并分类转移到各个文件夹中的实现代码
2020/07/21 Python
2020版Python学习路线图(附学习资料)
2020/09/15 Python
详解Python yaml模块
2020/09/23 Python
DRF使用simple JWT身份验证的实现
2021/01/14 Python
上海雨人软件技术开发有限公司测试题
2015/07/14 面试题
高三自我鉴定怎么写
2013/10/19 职场文书
机关党员三严三实心得体会
2014/10/13 职场文书
融资合作协议书范本
2014/10/17 职场文书
个人年终总结开头
2015/03/06 职场文书
2015年保洁工作总结范文
2015/04/28 职场文书
css实现两栏布局,左侧固定宽,右侧自适应的多种方法
2021/08/07 HTML / CSS
Nginx虚拟主机的搭建的实现步骤
2022/01/18 Servers
Win11 KB5015814遇安装失败 影响开始菜单性能解决方法
2022/07/15 数码科技