浅谈对pytroch中torch.autograd.backward的思考


Posted in Python onDecember 27, 2019

反向传递法则是深度学习中最为重要的一部分,torch中的backward可以对计算图中的梯度进行计算和累积

这里通过一段程序来演示基本的backward操作以及需要注意的地方

>>> import torch
>>> from torch.autograd import Variable

>>> x = Variable(torch.ones(2,2), requires_grad=True)
>>> y = x + 2
>>> y.grad_fn
Out[6]: <torch.autograd.function.AddConstantBackward at 0x229e7068138>
>>> y.grad

>>> z = y*y*3
>>> z.grad_fn
Out[9]: <torch.autograd.function.MulConstantBackward at 0x229e86cc5e8>
>>> z
Out[10]: 
Variable containing:
 27 27
 27 27
[torch.FloatTensor of size 2x2]
>>> out = z.mean()
>>> out.grad_fn
Out[12]: <torch.autograd.function.MeanBackward at 0x229e86cc408>
>>> out.backward()   # 这里因为out为scalar标量,所以参数不需要填写
>>> x.grad
Out[19]: 
Variable containing:
 4.5000 4.5000
 4.5000 4.5000
[torch.FloatTensor of size 2x2]
>>> out  # out为标量
Out[20]: 
Variable containing:
 27
[torch.FloatTensor of size 1]

>>> x = Variable(torch.Tensor([2,2,2]), requires_grad=True)
>>> y = x*2
>>> y
Out[52]: 
Variable containing:
 4
 4
 4
[torch.FloatTensor of size 3]
>>> y.backward() # 因为y输出为非标量,求向量间元素的梯度需要对所求的元素进行标注,用相同长度的序列进行标注
Traceback (most recent call last):
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\IPython\core\interactiveshell.py", line 2862, in run_code
  exec(code_obj, self.user_global_ns, self.user_ns)
 File "<ipython-input-53-95acac9c3254>", line 1, in <module>
  y.backward()
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\variable.py", line 156, in backward
  torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 86, in backward
  grad_variables, create_graph = _make_grads(variables, grad_variables, create_graph)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 34, in _make_grads
  raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        #注意这里的0.1,1.10为梯度求值比例
Out[55]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        # 梯度累积
Out[57]: 
Variable containing:
 0.4000
 4.0000
 40.0000
[torch.FloatTensor of size 3]

>>> x.grad.data.zero_() # 梯度累积进行清零
Out[60]: 
 0
 0
 0
[torch.FloatTensor of size 3]
>>> x.grad       # 累积为空
Out[61]: 
Variable containing:
 0
 0
 0
[torch.FloatTensor of size 3]
>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad
Out[63]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 获取当天凌晨零点的时间戳方法
May 22 Python
python实现windows下文件备份脚本
May 27 Python
Python3对称加密算法AES、DES3实例详解
Dec 06 Python
django中ORM模型常用的字段的使用方法
Mar 05 Python
python初学者,用python实现基本的学生管理系统(python3)代码实例
Apr 10 Python
Django框架搭建的简易图书信息网站案例
May 25 Python
使用pyecharts生成Echarts网页的实例
Aug 12 Python
py-charm延长试用期限实例
Dec 22 Python
Python3搭建http服务器的实现代码
Feb 11 Python
TensorFlow2.1.0安装过程中setuptools、wrapt等相关错误指南
Apr 08 Python
python实现斗地主分牌洗牌
Jun 22 Python
Python Matplotlib绘图基础知识代码解析
Aug 31 Python
python实现异常信息堆栈输出到日志文件
Dec 26 #Python
Python的对象传递与Copy函数使用详解
Dec 26 #Python
Python pandas库中的isnull()详解
Dec 26 #Python
python dataframe NaN处理方式
Dec 26 #Python
python实现大战外星人小游戏实例代码
Dec 26 #Python
Python数据存储之 h5py详解
Dec 26 #Python
Python 使用 prettytable 库打印表格美化输出功能
Dec 26 #Python
You might like
实现在同一方法中获取当前方法中新赋值的session值解决方法
2014/06/26 PHP
利用PHPExcel读取Excel的数据和导出数据到Excel
2017/05/12 PHP
ThinkPHP框架获取最后一次执行SQL语句及变量调试简单操作示例
2018/06/13 PHP
jquery.lazyload  实现图片延迟加载jquery插件
2010/02/06 Javascript
基于jquery的可多选的下拉列表框
2012/07/20 Javascript
浅谈JavaScript字符集
2014/05/22 Javascript
Javascript显示和隐藏ul列表的方法
2015/07/15 Javascript
Node.js静态文件服务器改进版
2016/01/10 Javascript
js enter键激发事件实例代码
2016/08/17 Javascript
Bootstrap基本组件学习笔记之列表组(11)
2016/12/07 Javascript
BootStrap的select2既可以查询又可以输入的实现代码
2017/02/17 Javascript
vue.js从安装到搭建过程详解
2017/03/17 Javascript
JavaScript判断变量名是否存在数组中的实例
2017/12/28 Javascript
Vue全局分页组件的实现代码
2018/08/10 Javascript
jquery操作select常见方法大全【7种情况】
2019/05/28 jQuery
通过实践编写优雅的JavaScript代码
2019/05/30 Javascript
vue eslint简要配置教程详解
2019/07/26 Javascript
[01:38]DOTA2 2015国际邀请赛中国区预选赛 Showopen
2015/06/01 DOTA
Python判断某个用户对某个文件的权限
2016/10/13 Python
Python基于Logistic回归建模计算某银行在降低贷款拖欠率的数据示例
2019/01/23 Python
Python 窗体(tkinter)按钮 位置实例
2019/06/13 Python
pyQT5 实现窗体之间传值的示例
2019/06/20 Python
Python3 main函数使用sys.argv传入多个参数的实现
2019/12/25 Python
python爬虫筛选工作实例讲解
2020/11/23 Python
纯html5+css3下拉导航菜单实现代码
2013/03/18 HTML / CSS
解决canvas转base64/jpeg时透明区域变成黑色背景的方法
2016/10/23 HTML / CSS
this关键字的含义
2015/04/08 面试题
说说在weblogic中开发消息Bean时的persistent与non-persisten的差别
2013/04/07 面试题
市场营销专业毕业生求职信
2014/07/21 职场文书
2014物价局民主生活会对照检查材料思想汇报
2014/09/24 职场文书
2015年护士医德医风自我评价
2015/03/03 职场文书
铁人纪念馆观后感
2015/06/16 职场文书
毕业典礼主持词
2015/06/29 职场文书
2015年中秋晚会主持稿
2015/07/30 职场文书
PyQt5 显示超清高分辨率图片的方法
2021/04/11 Python
python manim实现排序算法动画示例
2022/08/14 Python