在pytorch中实现只让指定变量向后传播梯度


Posted in Python onFebruary 29, 2020

pytorch中如何只让指定变量向后传播梯度?

(或者说如何让指定变量不参与后向传播?)

有以下公式,假如要让L对xvar求导:

在pytorch中实现只让指定变量向后传播梯度

(1)中,L对xvar的求导将同时计算out1部分和out2部分;

(2)中,L对xvar的求导只计算out2部分,因为out1的requires_grad=False;

(3)中,L对xvar的求导只计算out1部分,因为out2的requires_grad=False;

验证如下:

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed May 23 10:02:04 2018
@author: hy
"""
 
import torch
from torch.autograd import Variable
print("Pytorch version: {}".format(torch.__version__))
x=torch.Tensor([1])
xvar=Variable(x,requires_grad=True)
y1=torch.Tensor([2])
y2=torch.Tensor([7])
y1var=Variable(y1)
y2var=Variable(y2)
#(1)
print("For (1)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
#(2)
print("For (2)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
out1 = out1.detach()
print("after out1.detach(), out1 requres_grad: {}".format(out1.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
#(3)
print("For (3)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
#out1 = out1.detach()
out2 = out2.detach()
print("after out2.detach(), out2 requres_grad: {}".format(out1.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()

pytorch中,将变量的requires_grad设为False,即可让变量不参与梯度的后向传播;

但是不能直接将out1.requires_grad=False;

其实,Variable类型提供了detach()方法,所返回变量的requires_grad为False。

注意:如果out1和out2的requires_grad都为False的话,那么xvar.grad就出错了,因为梯度没有传到xvar

补充:

volatile=True表示这个变量不计算梯度, 参考:Volatile is recommended for purely inference mode, when you're sure you won't be even calling .backward(). It's more efficient than any other autograd setting - it will use the absolute minimal amount of memory to evaluate the model. volatile also determines that requires_grad is False.

以上这篇在pytorch中实现只让指定变量向后传播梯度就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python编程入门的一些基本知识
May 13 Python
Python base64编码解码实例
Jun 21 Python
Python解析最简单的验证码
Jan 07 Python
浅谈python 线程池threadpool之实现
Nov 17 Python
Python全排列操作实例分析
Jul 24 Python
python多进程控制学习小结
Oct 31 Python
python tkinter实现界面切换的示例代码
Jun 14 Python
Python字典对象实现原理详解
Jul 01 Python
浅谈Django QuerySet对象(模型.objects)的常用方法
Mar 28 Python
详解Django中的FBV和CBV对比分析
Mar 01 Python
python文本处理的方案(结巴分词并去除符号)
May 26 Python
Python基于百度API识别并提取图片中文字
Jun 27 Python
浅谈Pytorch中的自动求导函数backward()所需参数的含义
Feb 29 #Python
python数据预处理 :样本分布不均的解决(过采样和欠采样)
Feb 29 #Python
python实现门限回归方式
Feb 29 #Python
Python3.9又更新了:dict内置新功能
Feb 28 #Python
python实现logistic分类算法代码
Feb 28 #Python
python GUI库图形界面开发之PyQt5打印控件QPrinter详细使用方法与实例
Feb 28 #Python
使用sklearn的cross_val_score进行交叉验证实例
Feb 28 #Python
You might like
php excel类 phpExcel使用方法介绍
2010/08/21 PHP
PHP的cURL库功能简介 抓取网页、POST数据及其他
2011/04/07 PHP
在php中判断一个请求是ajax请求还是普通请求的方法
2011/06/28 PHP
php截取html字符串及自动补全html标签的方法
2015/01/15 PHP
Yii2中SqlDataProvider用法示例
2016/09/22 PHP
PHP配置ZendOpcache插件加速
2019/02/14 PHP
php curl发送请求实例方法
2019/08/01 PHP
Javascript 类型转换方法
2010/10/24 Javascript
js代码实现的加入收藏效果并兼容主流浏览器
2014/06/23 Javascript
node.js中的events.emitter.removeAllListeners方法使用说明
2014/12/10 Javascript
如何用javascript计算文本框还能输入多少个字符
2015/07/29 Javascript
Jquery easyui 实现动态树
2015/11/17 Javascript
AngularJS整合Springmvc、Spring、Mybatis搭建开发环境
2016/02/25 Javascript
JavaScript模板引擎Template.js使用详解
2016/12/15 Javascript
关于vue.js v-bind 的一些理解和思考
2017/06/06 Javascript
vue实现nav导航栏的方法
2017/12/13 Javascript
vue多层嵌套路由实例分析
2019/03/19 Javascript
解决前后端分离 vue+springboot 跨域 session+cookie失效问题
2019/05/13 Javascript
基于JavaScript判断两个对象内容是否相等
2020/01/10 Javascript
python list 合并连接字符串的方法
2013/03/09 Python
Python多线程下载文件的方法
2015/07/10 Python
如何使用七牛Python SDK写一个同步脚本及使用教程
2015/08/23 Python
django静态文件加载的方法
2018/05/20 Python
python中pip的安装与使用教程
2018/08/10 Python
pycharm: 恢复(reset) 误删文件的方法
2018/10/22 Python
对Python2与Python3中__bool__方法的差异详解
2018/11/01 Python
CSS3教程(9):设置RGB颜色
2009/04/02 HTML / CSS
CSS+jQuery+PHP+MySQL实现的在线答题功能
2015/04/25 HTML / CSS
旺仔牛奶广告词
2014/03/20 职场文书
求职信范文怎么写
2015/03/19 职场文书
标枪加油稿
2015/07/22 职场文书
CSS3实现的水平标题菜单
2021/04/14 HTML / CSS
基于Python和openCV实现图像的全景拼接详细步骤
2021/10/05 Python
JavaScript事件的委托(代理)的用法示例详解
2022/02/18 Javascript
python基础之//、/与%的区别详解
2022/06/10 Python
使用opencv-python如何打开USB或者笔记本前置摄像头
2022/06/21 Python