基于TensorFlow中自定义梯度的2种方式


Posted in Python onFebruary 04, 2020

前言

在深度学习中,有时候我们需要对某些节点的梯度进行一些定制,特别是该节点操作不可导(比如阶梯除法如 基于TensorFlow中自定义梯度的2种方式 ),如果实在需要对这个节点进行操作,而且希望其可以反向传播,那么就需要对其进行自定义反向传播时的梯度。在有些场景,如[2]中介绍到的梯度反转(gradient inverse)中,就必须在某层节点对反向传播的梯度进行反转,也就是需要更改正常的梯度传播过程,如下图的 基于TensorFlow中自定义梯度的2种方式 所示。

基于TensorFlow中自定义梯度的2种方式

在tensorflow中有若干可以实现定制梯度的方法,这里介绍两种。

1. 重写梯度法

重写梯度法指的是通过tensorflow自带的机制,将某个节点的梯度重写(override),这种方法的适用性最广。我们这里举个例子[3].

符号函数的前向传播采用的是阶跃函数y=sign(x) y = \rm{sign}(x)y=sign(x),如下图所示,我们知道阶跃函数不是连续可导的,因此我们在反向传播时,将其替代为一个可以连续求导的函数y=Htanh(x) y = \rm{Htanh(x)}y=Htanh(x),于是梯度就是大于1和小于-1时为0,在-1和1之间时是1。

基于TensorFlow中自定义梯度的2种方式

使用重写梯度的方法如下,主要是涉及到tf.RegisterGradient()和tf.get_default_graph().gradient_override_map(),前者注册新的梯度,后者重写图中具有名字name='Sign'的操作节点的梯度,用在新注册的QuantizeGrad替代。

#使用修饰器,建立梯度反向传播函数。其中op.input包含输入值、输出值,grad包含上层传来的梯度
@tf.RegisterGradient("QuantizeGrad")
def sign_grad(op, grad):
 input = op.inputs[0] # 取出当前的输入
 cond = (input>=-1)&(input<=1) # 大于1或者小于-1的值的位置
 zeros = tf.zeros_like(grad) # 定义出0矩阵用于掩膜
 return tf.where(cond, grad, zeros) 
 # 将大于1或者小于-1的上一层的梯度置为0
 
#使用with上下文管理器覆盖原始的sign梯度函数
def binary(input):
 x = input
 with tf.get_default_graph().gradient_override_map({"Sign":'QuantizeGrad'}):
 #重写梯度
  x = tf.sign(x)
 return x
 
#使用
x = binary(x)

其中的def sign_grad(op, grad):是注册新的梯度的套路,其中的op是当前操作的输入值/张量等,而grad指的是从反向而言的上一层的梯度。

通常来说,在tensorflow中自定义梯度,函数tf.identity()是很重要的,其API手册如下:

tf.identity(
 input,
 name=None
)

其会返回一个形状和内容都和输入完全一样的输出,但是你可以自定义其反向传播时的梯度,因此在梯度反转等操作中特别有用。

这里再举个反向梯度[2]的例子,也就是梯度为 基于TensorFlow中自定义梯度的2种方式 而不是 基于TensorFlow中自定义梯度的2种方式

import tensorflow as tf
x1 = tf.Variable(1)
x2 = tf.Variable(3)
x3 = tf.Variable(6)
@tf.RegisterGradient('CustomGrad')
def CustomGrad(op, grad):
#  tf.Print(grad)
 return -grad
 
g = tf.get_default_graph()
oo = x1+x2
with g.gradient_override_map({"Identity": "CustomGrad"}):
 output = tf.identity(oo)
grad_1 = tf.gradients(output, oo)
with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 print(sess.run(grad_1))

因为-grad,所以这里的梯度输出是[-1]而不是[1]。有一个我们需要注意的是,在自定义函数def CustomGrad()中,返回的值得是一个张量,而不能返回一个参数,比如return 0,这样会报错,如:

AttributeError: 'int' object has no attribute 'name'

显然,这是因为tensorflow的内部操作需要取返回值的名字而int类型没有名字。

PS:def CustomGrad()这个函数签名是随便你取的。

2. stop_gradient法

对于自定义梯度,还有一种比较简洁的操作,就是利用tf.stop_gradient()函数,我们看下例子[1]:

t = g(x)
y = t + tf.stop_gradient(f(x) - t)

这里,我们本来的前向传递函数是f(x),但是想要在反向时传递的函数是g(x),因为在前向过程中,tf.stop_gradient()不起作用,因此+t和-t抵消掉了,只剩下f(x)前向传递;而在反向过程中,因为tf.stop_gradient()的作用,使得f(x)-t的梯度变为了0,从而只剩下g(x)在反向传递。

我们看下完整的例子:

import tensorflow as tf

x1 = tf.Variable(1)
x2 = tf.Variable(3)
x3 = tf.Variable(6)

f = x1+x2*x3
t = -f

y1 = t + tf.stop_gradient(f-t)
y2 = f

grad_1 = tf.gradients(y1, x1)
grad_2 = tf.gradients(y2, x1)
with tf.Session(config=config) as sess:
 sess.run(tf.global_variables_initializer())

 print(sess.run(grad_1))
 print(sess.run(grad_2))

第一个输出为[-1],第二个输出为[1],显然也实现了梯度的反转。

以上这篇基于TensorFlow中自定义梯度的2种方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
pycharm 使用心得(一)安装和首次使用
Jun 05 Python
跟老齐学Python之??碌某?? target=
Sep 12 Python
Python引用模块和查找模块路径
Mar 17 Python
python爬虫入门教程--快速理解HTTP协议(一)
May 25 Python
使用python实现tcp自动重连
Jul 02 Python
python下载图片实现方法(超简单)
Jul 21 Python
详解python使用Nginx和uWSGI来运行Python应用
Jan 09 Python
python处理csv数据动态显示曲线实例代码
Jan 23 Python
利用pandas将numpy数组导出生成excel的实例
Jun 14 Python
pycharm 在windows上编辑代码用linux执行配置的方法
Oct 27 Python
Python日期格式和字符串格式相互转换的方法
Feb 18 Python
Pycharm中安装wordcloud等库失败问题及终端通过pip安装的Python库如何添加到Pycharm解释器中(推荐)
May 10 Python
tensorflow 查看梯度方式
Feb 04 #Python
opencv python图像梯度实例详解
Feb 04 #Python
TensorFlow设置日志级别的几种方式小结
Feb 04 #Python
Python 实现加密过的PDF文件转WORD格式
Feb 04 #Python
解决tensorflow打印tensor有省略号的问题
Feb 04 #Python
对Tensorflow中tensorboard日志的生成与显示详解
Feb 04 #Python
在 Python 中接管键盘中断信号的实现方法
Feb 04 #Python
You might like
PHP无敌近乎加密方式!
2010/07/17 PHP
解析PHP计算页面执行时间的实现代码
2013/06/18 PHP
Linux下PHP加速器APC的安装与配置笔记
2014/10/24 PHP
PHP中SESSION过期设置
2021/03/09 PHP
需要做特殊处理的DOM元素属性的访问
2010/11/05 Javascript
jQuery实现购物车多物品数量的加减+总价计算
2014/06/06 Javascript
jQuery遍历json中多个map的方法
2015/02/12 Javascript
Javascript实现商品秒杀倒计时(时间与服务器时间同步)
2015/09/16 Javascript
nodejs微信扫码支付功能实现
2018/02/17 NodeJs
vue中实现图片和文件上传的示例代码
2018/03/16 Javascript
详谈js的变量提升以及使用方法
2018/10/06 Javascript
基于Vue实现可以拖拽的树形表格实例详解
2018/10/18 Javascript
elementUI select组件使用及注意事项详解
2019/05/29 Javascript
Vue组件间通信 Vuex的用法解析
2019/08/05 Javascript
Vue Extends 扩展选项用法完整实例
2019/09/17 Javascript
vue跳转页面的几种方法(推荐)
2020/03/26 Javascript
vue 实现基础组件的自动化全局注册
2020/12/25 Vue.js
[03:36]DOTA2完美大师赛coL战队趣味视频——我演你猜
2017/11/23 DOTA
[00:35]可解锁地面特效
2018/12/20 DOTA
[01:46]2020完美世界全国高校联赛秋季赛报名开启
2020/10/15 DOTA
[36:05]完美世界DOTA2联赛循环赛 Forest vs DM 第一场 11.06
2020/11/06 DOTA
采用Psyco实现python执行速度提高到与编译语言一样的水平
2014/10/11 Python
在Python中使用HTMLParser解析HTML的教程
2015/04/29 Python
Python中optparser库用法实例详解
2018/01/26 Python
查看python安装路径及pip安装的包列表及路径
2019/04/03 Python
几个解决兼容IE6\7\8不支持html5标签的几个方法
2013/01/07 HTML / CSS
世界上最大的专业美容用品零售商:Sally Beauty
2017/07/02 全球购物
加拿大当代时尚服饰、配饰和鞋类专业零售商和制造商:LE CHÂTEAU
2017/10/06 全球购物
个人简历自我评价
2014/01/06 职场文书
4s店销售经理岗位职责
2014/07/19 职场文书
小学趣味运动会加油稿
2014/09/25 职场文书
致百米运动员广播稿5篇
2014/10/13 职场文书
2014年教研组工作总结
2014/11/26 职场文书
幼儿园大班教师随笔
2015/08/14 职场文书
2019最新校园运动会广播稿!
2019/06/28 职场文书
创业计划书之干洗店
2019/09/10 职场文书