在pytorch中实现只让指定变量向后传播梯度


Posted in Python onFebruary 29, 2020

pytorch中如何只让指定变量向后传播梯度?

(或者说如何让指定变量不参与后向传播?)

有以下公式,假如要让L对xvar求导:

在pytorch中实现只让指定变量向后传播梯度

(1)中,L对xvar的求导将同时计算out1部分和out2部分;

(2)中,L对xvar的求导只计算out2部分,因为out1的requires_grad=False;

(3)中,L对xvar的求导只计算out1部分,因为out2的requires_grad=False;

验证如下:

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed May 23 10:02:04 2018
@author: hy
"""
 
import torch
from torch.autograd import Variable
print("Pytorch version: {}".format(torch.__version__))
x=torch.Tensor([1])
xvar=Variable(x,requires_grad=True)
y1=torch.Tensor([2])
y2=torch.Tensor([7])
y1var=Variable(y1)
y2var=Variable(y2)
#(1)
print("For (1)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
#(2)
print("For (2)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
out1 = out1.detach()
print("after out1.detach(), out1 requres_grad: {}".format(out1.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
#(3)
print("For (3)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
#out1 = out1.detach()
out2 = out2.detach()
print("after out2.detach(), out2 requres_grad: {}".format(out1.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()

pytorch中,将变量的requires_grad设为False,即可让变量不参与梯度的后向传播;

但是不能直接将out1.requires_grad=False;

其实,Variable类型提供了detach()方法,所返回变量的requires_grad为False。

注意:如果out1和out2的requires_grad都为False的话,那么xvar.grad就出错了,因为梯度没有传到xvar

补充:

volatile=True表示这个变量不计算梯度, 参考:Volatile is recommended for purely inference mode, when you're sure you won't be even calling .backward(). It's more efficient than any other autograd setting - it will use the absolute minimal amount of memory to evaluate the model. volatile also determines that requires_grad is False.

以上这篇在pytorch中实现只让指定变量向后传播梯度就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python简单日志处理类分享
Feb 14 Python
Python实现求最大公约数及判断素数的方法
May 26 Python
Python编程中的文件读写及相关的文件对象方法讲解
Jan 19 Python
Python使用自带的ConfigParser模块读写ini配置文件
Jun 26 Python
详解pyqt5 动画在QThread线程中无法运行问题
May 05 Python
在PyCharm环境中使用Jupyter Notebook的两种方法总结
May 24 Python
Pycharm之快速定位到某行快捷键的方法
Jan 20 Python
Python判断对象是否为文件对象(file object)的三种方法示例
Apr 26 Python
Python和Java的语法对比分析语法简洁上python的确完美胜出
May 10 Python
用Python从0开始实现一个中文拼音输入法的思路详解
Jul 20 Python
django的聚合函数和aggregate、annotate方法使用详解
Jul 23 Python
numpy按列连接两个维数不同的数组方式
Dec 06 Python
浅谈Pytorch中的自动求导函数backward()所需参数的含义
Feb 29 #Python
python数据预处理 :样本分布不均的解决(过采样和欠采样)
Feb 29 #Python
python实现门限回归方式
Feb 29 #Python
Python3.9又更新了:dict内置新功能
Feb 28 #Python
python实现logistic分类算法代码
Feb 28 #Python
python GUI库图形界面开发之PyQt5打印控件QPrinter详细使用方法与实例
Feb 28 #Python
使用sklearn的cross_val_score进行交叉验证实例
Feb 28 #Python
You might like
手把手教你使用DedeCms V3的在线采集图文教程
2007/04/03 PHP
教你识别简单的免查杀PHP后门
2015/09/13 PHP
教你在header中隐藏php的版本信息
2016/08/10 PHP
thinkPHP统计排行与分页显示功能示例
2016/12/02 PHP
php使用自定义函数实现汉字分割替换功能示例
2017/01/30 PHP
JavaScript isArray()函数判断对象类型的种种方法
2010/10/11 Javascript
浅析jQuery对select操作小结(遍历option,操作option)
2013/07/04 Javascript
javascript预加载图片、css、js的方法示例介绍
2013/10/14 Javascript
js写出遮罩层登陆框和对联广告并自动跟随滚动条滚动
2014/04/29 Javascript
在css加载完毕后自动判断页面是否加入css或js文件
2014/09/10 Javascript
输入框过滤非数字的js代码
2014/09/18 Javascript
基于JavaScript代码实现pc与手机之间的跳转
2015/12/23 Javascript
微信小程序  wx.request合法域名配置详解
2016/11/23 Javascript
jQuery插件HighCharts绘制简单2D柱状图效果示例【附demo源码】
2017/03/21 jQuery
详解angular中的作用域及继承
2017/05/31 Javascript
详解可以用在VS Code中的正则表达式小技巧
2019/05/14 Javascript
Vue3 的响应式和以前有什么区别,Proxy 无敌?
2020/05/20 Javascript
[14:36]2014 DOTA2国际邀请赛中国区预选赛5.21 Orenda VS NE
2014/05/22 DOTA
对于Python异常处理慎用“except:pass”建议
2015/04/02 Python
自己编程中遇到的Python错误和解决方法汇总整理
2015/06/03 Python
python自动12306抢票软件实现代码
2018/02/24 Python
深入浅析Python获取对象信息的函数type()、isinstance()、dir()
2018/09/17 Python
基于Python实现迪杰斯特拉和弗洛伊德算法
2020/05/27 Python
python生成每日报表数据(Excel)并邮件发送的实例
2019/02/03 Python
python数据分析工具之 matplotlib详解
2020/04/09 Python
Python DataFrame使用drop_duplicates()函数去重(保留重复值,取重复值)
2020/07/20 Python
django有哪些好处和优点
2020/09/01 Python
Python爬虫之Selenium多窗口切换的实现
2020/12/04 Python
英国著名的茶叶品牌:Whittard of Chelsea
2016/09/22 全球购物
男女时尚与复古风格在线购物:RoseGal(全球免费送货)
2017/07/19 全球购物
白俄罗斯大卖场:21vek.by
2019/07/25 全球购物
德国大型箱包和皮具商店:Koffer
2019/10/01 全球购物
中科前程Java笔试题
2016/11/20 面试题
西式结婚主持词
2014/03/14 职场文书
停车场管理协议书范本
2014/10/08 职场文书
社区植树节活动总结
2015/02/06 职场文书