在pytorch中实现只让指定变量向后传播梯度


Posted in Python onFebruary 29, 2020

pytorch中如何只让指定变量向后传播梯度?

(或者说如何让指定变量不参与后向传播?)

有以下公式,假如要让L对xvar求导:

在pytorch中实现只让指定变量向后传播梯度

(1)中,L对xvar的求导将同时计算out1部分和out2部分;

(2)中,L对xvar的求导只计算out2部分,因为out1的requires_grad=False;

(3)中,L对xvar的求导只计算out1部分,因为out2的requires_grad=False;

验证如下:

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed May 23 10:02:04 2018
@author: hy
"""
 
import torch
from torch.autograd import Variable
print("Pytorch version: {}".format(torch.__version__))
x=torch.Tensor([1])
xvar=Variable(x,requires_grad=True)
y1=torch.Tensor([2])
y2=torch.Tensor([7])
y1var=Variable(y1)
y2var=Variable(y2)
#(1)
print("For (1)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
#(2)
print("For (2)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
out1 = out1.detach()
print("after out1.detach(), out1 requres_grad: {}".format(out1.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
#(3)
print("For (3)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
#out1 = out1.detach()
out2 = out2.detach()
print("after out2.detach(), out2 requres_grad: {}".format(out1.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()

pytorch中,将变量的requires_grad设为False,即可让变量不参与梯度的后向传播;

但是不能直接将out1.requires_grad=False;

其实,Variable类型提供了detach()方法,所返回变量的requires_grad为False。

注意:如果out1和out2的requires_grad都为False的话,那么xvar.grad就出错了,因为梯度没有传到xvar

补充:

volatile=True表示这个变量不计算梯度, 参考:Volatile is recommended for purely inference mode, when you're sure you won't be even calling .backward(). It's more efficient than any other autograd setting - it will use the absolute minimal amount of memory to evaluate the model. volatile also determines that requires_grad is False.

以上这篇在pytorch中实现只让指定变量向后传播梯度就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python NumPy库安装使用笔记
May 18 Python
Python实现建立SSH连接的方法
Jun 03 Python
python用pickle模块实现“增删改查”的简易功能
Jun 07 Python
python MNIST手写识别数据调用API的方法
Aug 08 Python
Python使用pydub库对mp3与wav格式进行互转的方法
Jan 10 Python
django的分页器Paginator 从django中导入类
Jul 25 Python
Python调用C语言的实现
Jul 26 Python
用Python徒手撸一个股票回测框架搭建【推荐】
Aug 05 Python
18个Python脚本可加速你的编码速度(提示和技巧)
Oct 17 Python
Python 解决OPEN读文件报错 ,路径以及r的问题
Dec 19 Python
python颜色随机生成器的实例代码
Jan 10 Python
python爬虫快速响应服务器的做法
Nov 24 Python
浅谈Pytorch中的自动求导函数backward()所需参数的含义
Feb 29 #Python
python数据预处理 :样本分布不均的解决(过采样和欠采样)
Feb 29 #Python
python实现门限回归方式
Feb 29 #Python
Python3.9又更新了:dict内置新功能
Feb 28 #Python
python实现logistic分类算法代码
Feb 28 #Python
python GUI库图形界面开发之PyQt5打印控件QPrinter详细使用方法与实例
Feb 28 #Python
使用sklearn的cross_val_score进行交叉验证实例
Feb 28 #Python
You might like
php 记录进行累加并显示总时长为秒的结果
2011/11/04 PHP
php使用Session和文件统计在线人数
2015/07/04 PHP
JQuery最佳实践之精妙的自定义事件
2010/08/11 Javascript
jQuery EasyUI API 中文文档 - NumberBox数字框
2011/10/13 Javascript
ParseInt函数参数设置介绍
2014/01/02 Javascript
jquery css 设置table的奇偶行背景色示例
2014/06/03 Javascript
angularjs实现与服务器交互分享
2014/06/24 Javascript
js实现点击图片改变页面背景图的方法
2015/02/28 Javascript
js实现动态创建的元素绑定事件
2016/07/19 Javascript
详解angularJs中自定义directive的数据交互
2017/01/13 Javascript
详解如何用webpack打包一个网站应用项目
2017/07/12 Javascript
javascript  数组排序与对象排序的实例
2017/07/17 Javascript
解决ionic和angular上拉加载的问题
2017/08/03 Javascript
vux uploader 图片上传组件的安装使用方法
2018/05/15 Javascript
纯异步nodejs文件夹(目录)复制功能
2019/09/03 NodeJs
VUE:vuex 用户登录信息的数据写入与获取方式
2019/11/11 Javascript
Python实现的监测服务器硬盘使用率脚本分享
2014/11/07 Python
Python MySQLdb模块连接操作mysql数据库实例
2015/04/08 Python
总结Python编程中三条常用的技巧
2015/05/11 Python
Python 3.x 连接数据库示例(pymysql 方式)
2017/01/19 Python
python机器学习之神经网络(一)
2017/12/20 Python
python批量复制图片到另一个文件夹
2018/09/17 Python
Python格式化输出字符串方法小结【%与format】
2018/10/29 Python
python获取交互式ssh shell的方法
2019/02/14 Python
python plotly画柱状图代码实例
2019/12/13 Python
python如何建立全零数组
2020/07/19 Python
Python爬虫实现自动登录、签到功能的代码
2020/08/20 Python
英国最大的手表网站:The Watch Hut
2017/03/31 全球购物
英国卫浴商店:Ergonomic Design
2019/09/22 全球购物
中专毕业个人的自荐信格式
2013/09/21 职场文书
行政管理专业推荐信
2013/11/02 职场文书
闭幕式主持词
2014/04/02 职场文书
水电站项目建议书
2014/05/12 职场文书
学校标语大全
2014/06/19 职场文书
商务代表岗位职责
2015/02/15 职场文书
MySQL控制流函数(-if ,elseif,else,case...when)
2022/07/07 MySQL