在pytorch中实现只让指定变量向后传播梯度


Posted in Python onFebruary 29, 2020

pytorch中如何只让指定变量向后传播梯度?

(或者说如何让指定变量不参与后向传播?)

有以下公式,假如要让L对xvar求导:

在pytorch中实现只让指定变量向后传播梯度

(1)中,L对xvar的求导将同时计算out1部分和out2部分;

(2)中,L对xvar的求导只计算out2部分,因为out1的requires_grad=False;

(3)中,L对xvar的求导只计算out1部分,因为out2的requires_grad=False;

验证如下:

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed May 23 10:02:04 2018
@author: hy
"""
 
import torch
from torch.autograd import Variable
print("Pytorch version: {}".format(torch.__version__))
x=torch.Tensor([1])
xvar=Variable(x,requires_grad=True)
y1=torch.Tensor([2])
y2=torch.Tensor([7])
y1var=Variable(y1)
y2var=Variable(y2)
#(1)
print("For (1)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
#(2)
print("For (2)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
out1 = out1.detach()
print("after out1.detach(), out1 requres_grad: {}".format(out1.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
#(3)
print("For (3)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
#out1 = out1.detach()
out2 = out2.detach()
print("after out2.detach(), out2 requres_grad: {}".format(out1.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()

pytorch中,将变量的requires_grad设为False,即可让变量不参与梯度的后向传播;

但是不能直接将out1.requires_grad=False;

其实,Variable类型提供了detach()方法,所返回变量的requires_grad为False。

注意:如果out1和out2的requires_grad都为False的话,那么xvar.grad就出错了,因为梯度没有传到xvar

补充:

volatile=True表示这个变量不计算梯度, 参考:Volatile is recommended for purely inference mode, when you're sure you won't be even calling .backward(). It's more efficient than any other autograd setting - it will use the absolute minimal amount of memory to evaluate the model. volatile also determines that requires_grad is False.

以上这篇在pytorch中实现只让指定变量向后传播梯度就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 装饰器功能以及函数参数使用介绍
Jan 27 Python
python通过yield实现数组全排列的方法
Mar 18 Python
Python的Django框架中的表单处理示例
Jul 17 Python
浅谈python for循环的巧妙运用(迭代、列表生成式)
Sep 26 Python
Python实现扩展内置类型的方法分析
Oct 16 Python
Python实现感知机(PLA)算法
Dec 20 Python
Python自定义函数定义,参数,调用代码解析
Dec 27 Python
pygame实现简易飞机大战
Sep 11 Python
python tkinter canvas 显示图片的示例
Jun 13 Python
Python 函数用法简单示例【定义、参数、返回值、函数嵌套】
Sep 20 Python
解决pycharm每次打开项目都需要配置解释器和安装库问题
Feb 26 Python
Python+Selenium实现读取网易邮箱验证码
Mar 13 Python
浅谈Pytorch中的自动求导函数backward()所需参数的含义
Feb 29 #Python
python数据预处理 :样本分布不均的解决(过采样和欠采样)
Feb 29 #Python
python实现门限回归方式
Feb 29 #Python
Python3.9又更新了:dict内置新功能
Feb 28 #Python
python实现logistic分类算法代码
Feb 28 #Python
python GUI库图形界面开发之PyQt5打印控件QPrinter详细使用方法与实例
Feb 28 #Python
使用sklearn的cross_val_score进行交叉验证实例
Feb 28 #Python
You might like
星际中一些鲜为人知的详细资料
2020/03/04 星际争霸
PHP也可以?成Shell Script
2006/10/09 PHP
PHP函数篇详解十进制、二进制、八进制和十六进制转换函数说明
2011/12/05 PHP
JavaScript实现滚动栏效果的方法
2015/04/27 PHP
Zend Framework教程之Loader以及PluginLoader用法详解
2016/03/09 PHP
PHP请求远程地址设置超时时间的解决方法
2016/10/29 PHP
goto语法在PHP中的使用教程
2020/09/17 PHP
仅IE不支持setTimeout/setInterval函数的第三个以上参数
2011/05/25 Javascript
使用 JScript 创建 .exe 或 .dll 文件的方法
2011/07/13 Javascript
深入分析JSON编码格式提交表单数据
2015/06/25 Javascript
JS+CSS实现仿msn风格选项卡效果代码
2015/10/22 Javascript
JavaScript如何调试有哪些建议和技巧附五款有用的调试工具
2015/10/28 Javascript
JS实现的N多简单无缝滚动代码(包含图文效果)
2015/11/06 Javascript
详解AngularJS中自定义过滤器
2015/12/28 Javascript
Bootstrap提示框效果的实例代码
2017/07/12 Javascript
angular2系列之路由转场动画的示例代码
2017/11/09 Javascript
原生JS实现网页手机音乐播放器 歌词同步播放的示例
2018/02/02 Javascript
Vue实现根据hash高亮选项卡
2019/05/27 Javascript
微信小程序自定义头部导航栏(组件化)
2019/11/15 Javascript
js实现验证码功能
2020/07/24 Javascript
javascript实现打砖块小游戏(附完整源码)
2020/09/18 Javascript
Vue获取微博授权URL代码实例
2020/11/04 Javascript
[03:54]DOTA2英雄梦之声_第06期_昆卡
2014/06/23 DOTA
Python中类的定义、继承及使用对象实例详解
2015/04/30 Python
Python元组操作实例分析【创建、赋值、更新、删除等】
2017/07/24 Python
python关于矩阵重复赋值覆盖问题的解决方法
2019/07/19 Python
18个Python脚本可加速你的编码速度(提示和技巧)
2019/10/17 Python
Pandas把dataframe或series转换成list的方法
2020/06/14 Python
解释一下抽象方法和抽象类
2016/08/27 面试题
数控专业个人求职信范例
2013/11/29 职场文书
司机检讨书
2014/02/13 职场文书
党员群众路线剖析材料
2014/10/08 职场文书
国庆横幅标语
2014/10/08 职场文书
焦点访谈观后感
2015/06/11 职场文书
2015年小学财务工作总结
2015/07/20 职场文书
Java并发编程之Executor接口的使用
2021/06/21 Java/Android