在pytorch中实现只让指定变量向后传播梯度


Posted in Python onFebruary 29, 2020

pytorch中如何只让指定变量向后传播梯度?

(或者说如何让指定变量不参与后向传播?)

有以下公式,假如要让L对xvar求导:

在pytorch中实现只让指定变量向后传播梯度

(1)中,L对xvar的求导将同时计算out1部分和out2部分;

(2)中,L对xvar的求导只计算out2部分,因为out1的requires_grad=False;

(3)中,L对xvar的求导只计算out1部分,因为out2的requires_grad=False;

验证如下:

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed May 23 10:02:04 2018
@author: hy
"""
 
import torch
from torch.autograd import Variable
print("Pytorch version: {}".format(torch.__version__))
x=torch.Tensor([1])
xvar=Variable(x,requires_grad=True)
y1=torch.Tensor([2])
y2=torch.Tensor([7])
y1var=Variable(y1)
y2var=Variable(y2)
#(1)
print("For (1)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
#(2)
print("For (2)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
out1 = out1.detach()
print("after out1.detach(), out1 requres_grad: {}".format(out1.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
#(3)
print("For (3)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
#out1 = out1.detach()
out2 = out2.detach()
print("after out2.detach(), out2 requres_grad: {}".format(out1.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()

pytorch中,将变量的requires_grad设为False,即可让变量不参与梯度的后向传播;

但是不能直接将out1.requires_grad=False;

其实,Variable类型提供了detach()方法,所返回变量的requires_grad为False。

注意:如果out1和out2的requires_grad都为False的话,那么xvar.grad就出错了,因为梯度没有传到xvar

补充:

volatile=True表示这个变量不计算梯度, 参考:Volatile is recommended for purely inference mode, when you're sure you won't be even calling .backward(). It's more efficient than any other autograd setting - it will use the absolute minimal amount of memory to evaluate the model. volatile also determines that requires_grad is False.

以上这篇在pytorch中实现只让指定变量向后传播梯度就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的石头剪子布代码分享
Aug 22 Python
python使用PyGame绘制图像并保存为图片文件的方法
Apr 24 Python
python的Crypto模块实现AES加密实例代码
Jan 22 Python
利用python实现短信和电话提醒功能的例子
Aug 08 Python
Django对models里的objects的使用详解
Aug 17 Python
python飞机大战pygame游戏框架搭建操作详解
Dec 17 Python
python自动化unittest yaml使用过程解析
Feb 03 Python
详解Python的三种拷贝方式
Feb 11 Python
使用python创建生成动态链接库dll的方法
May 09 Python
基于opencv的selenium滑动验证码的实现
Jul 24 Python
python 解决selenium 中的 .clear()方法失效问题
Sep 01 Python
如何使用Python实现一个简易的ORM模型
May 12 Python
浅谈Pytorch中的自动求导函数backward()所需参数的含义
Feb 29 #Python
python数据预处理 :样本分布不均的解决(过采样和欠采样)
Feb 29 #Python
python实现门限回归方式
Feb 29 #Python
Python3.9又更新了:dict内置新功能
Feb 28 #Python
python实现logistic分类算法代码
Feb 28 #Python
python GUI库图形界面开发之PyQt5打印控件QPrinter详细使用方法与实例
Feb 28 #Python
使用sklearn的cross_val_score进行交叉验证实例
Feb 28 #Python
You might like
第四节 构造函数和析构函数 [4]
2006/10/09 PHP
PHP PDO函数库详解
2010/04/27 PHP
JavaScript版代码高亮
2006/06/26 Javascript
flexigrid 参数说明
2010/11/23 Javascript
5秒后跳转效果(setInterval/SetTimeOut)
2013/05/03 Javascript
jquery 倒计时效果实现秒杀思路
2013/09/11 Javascript
Egret引擎开发指南之发布项目
2014/09/03 Javascript
jQuery操作表单常用控件方法小结
2015/03/23 Javascript
编写高性能Javascript代码的N条建议
2015/10/12 Javascript
js实现打地鼠小游戏
2017/02/13 Javascript
webpack v4 从dev到prd的方法
2018/04/02 Javascript
JS实现的JSON序列化操作简单示例
2018/07/02 Javascript
JavaScript实现移动端带transition动画的轮播效果
2020/03/24 Javascript
jQuery实现移动端图片上传预览组件的方法分析
2020/05/01 jQuery
原生JS封装拖动验证滑块的实现代码示例
2020/06/01 Javascript
Vue element-ui父组件控制子组件的表单校验操作
2020/07/17 Javascript
基于脚手架创建Vue项目实现步骤详解
2020/08/03 Javascript
[03:01]2014DOTA2国际邀请赛 DC:我是核弹粉,为Burning和国土祝福
2014/07/13 DOTA
Mac中Python 3环境下安装scrapy的方法教程
2017/10/26 Python
python验证码识别实例代码
2018/02/03 Python
python使用tornado实现简单爬虫
2018/07/28 Python
python调用Matplotlib绘制分布点图
2019/10/18 Python
python打印直角三角形与等腰三角形实例代码
2019/10/20 Python
Python3搭建http服务器的实现代码
2020/02/11 Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
2020/05/25 Python
Django-Scrapy生成后端json接口的方法示例
2020/10/06 Python
python如何编写类似nmap的扫描工具
2020/11/06 Python
一波HTML5 Canvas基础绘图实例代码集合
2016/02/28 HTML / CSS
html5 canvas简单封装一个echarts实现不了的饼图
2018/06/12 HTML / CSS
鼓励运动员的广播稿
2014/02/08 职场文书
社保委托书怎么写
2014/08/02 职场文书
涉外离婚协议书怎么写
2014/11/20 职场文书
商务英语邮件开头问候语
2015/11/10 职场文书
创业计划书之书店
2019/09/10 职场文书
Python基础之数据类型知识汇总
2021/05/18 Python
如何在Python中妥善使用进度条详解
2022/04/05 Python