Tensorflow中k.gradients()和tf.stop_gradient()用法说明


Posted in Python onJune 10, 2020

上周在实验室开荒某个代码,看到中间这么一段,对Tensorflow中的stop_gradient()还不熟悉,特此周末进行重新并总结。

y = xx + K.stop_gradient(rounded - xx)

这代码最终调用位置在tensoflow.python.ops.gen_array_ops.stop_gradient(input, name=None),关于这段代码为什么这样写的意义在文末给出。

【stop_gradient()意义】

用stop_gradient生成损失函数w.r.t.的梯度。

【tf.gradients()理解】

tf中我们只需要设计我们自己的函数,tf提供提供强大的自动计算函数梯度方法,tf.gradients()。

tf.gradients(
 ys,
 xs,
 grad_ys=None,
 name='gradients',
 colocate_gradients_with_ops=False,
 gate_gradients=False,
 aggregation_method=None,
 stop_gradients=None,
 unconnected_gradients=tf.UnconnectedGradients.NONE
)

gradients() adds ops to the graph to output the derivatives of ys with respect to xs. It returns a list of Tensor of length len(xs) where each tensor is the sum(dy/dx) for y in ys.

1、tf.gradients()实现ys对xs的求导

2、ys和xs可以是Tensor或者list包含的Tensor

3、求导返回值是一个list,list的长度等于len(xs)

eg.假设返回值是[grad1, grad2, grad3],ys=[y1, y2],xs=[x1, x2, x3]。则计算过程为:

Tensorflow中k.gradients()和tf.stop_gradient()用法说明

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
gradients_node = tf.gradients(loss_op, w)
 
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([train_op, gradients_node, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

自定义梯度和更新函数

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
# train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
 
#自定义权重更新
grad_w, grad_b = tf.gradients(loss_op, [w, b])
new_w = w.assign(w - 0.01 * grad_w)
new_b = b.assign(b - 0.01 * grad_b)
 
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([new_w, new_b, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

【tf.stop_gradient()理解】

在tf.gradients()参数中存在stop_gradients,这是一个List,list中的元素是tensorflow graph中的op,一旦进入这个list,将不会被计算梯度,更重要的是,在该op之后的BP计算都不会运行。

import numpy as np
import tensorflow as tf
 
a = tf.constant(0.)
b = 2 * a
c = a + b
g = tf.gradients(c, [a, b])
 
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(g))
 
#输出[3.0, 1.0]

在用一个stop_gradient()的例子

import tensorflow as tf
 
#实验一
w1 = tf.Variable(2.0)
w2 = tf.Variable(2.0)
a = tf.multiply(w1, 3.0)
a_stoped = tf.stop_gradient(a)
 
# b=w1*3.0*w2
b = tf.multiply(a_stoped, w2)
gradients = tf.gradients(b, xs=[w1, w2])
print(gradients)
#输出[None, <tf.Tensor 'gradients/Mul_1_grad/Reshape_1:0' shape=() dtype=float32>]
 
#实验二
a = tf.Variable(1.0)
b = tf.Variable(1.0)
c = tf.add(a, b)
c_stoped = tf.stop_gradient(c)
d = tf.add(a, b)
e = tf.add(c_stoped, d)
gradients = tf.gradients(e, xs=[a, b])
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(gradients))
 
#因为梯度从另外地方传回,所以输出 [1.0, 1.0]

【答案】

开始提出的问题,为什么存在那段代码:

t = g(x)

y = t + tf.stop_gradient(f(x) - t)

这里,我们本来的前向传递函数是XX,但是想要在反向时传递的函数是g(x),因为在前向过程中,tf.stop_gradient()不起作用,因此+t和-t抵消掉了,只剩下f(x)前向传递;而在反向过程中,因为tf.stop_gradient()的作用,使得f(x)-t的梯度变为了0,从而只剩下g(x)在反向传递。

以上这篇Tensorflow中k.gradients()和tf.stop_gradient()用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3处理文件中每个词的方法
May 22 Python
python实现linux下使用xcopy的方法
Jun 28 Python
python的random模块及加权随机算法的python实现方法
Jan 04 Python
Python实现完整的事务操作示例
Jun 20 Python
微信跳一跳游戏python脚本
Apr 01 Python
Python 实现加密过的PDF文件转WORD格式
Feb 04 Python
通过python连接Linux命令行代码实例
Feb 18 Python
Python使用re模块验证危险字符
May 21 Python
python怎么删除缓存文件
Jul 19 Python
Python 实现PS滤镜的旋涡特效
Dec 03 Python
Python办公自动化解决world文件批量转换
Sep 15 Python
Python 阶乘详解
Oct 05 Python
PySide2出现“ImportError: DLL load failed: 找不到指定的模块”的问题及解决方法
Jun 10 #Python
浅谈Python中的字符串
Jun 10 #Python
Keras 使用 Lambda层详解
Jun 10 #Python
keras打印loss对权重的导数方式
Jun 10 #Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
You might like
php中突破基于HTTP_REFERER的防盗链措施(stream_context_create)
2011/03/29 PHP
PHP文件去掉PHP注释空格的函数分析(PHP代码压缩)
2013/07/02 PHP
Yii2框架数据库简单的增删改查语法小结
2016/08/31 PHP
Yii框架实现记录日志到自定义文件的方法
2017/05/23 PHP
php 删除一维数组中某一个值元素的操作方法
2018/02/01 PHP
php使用curl获取header检测开启GZip压缩的方法
2018/08/15 PHP
使用laravel和ajax实现整个页面无刷新的操作方法
2019/10/03 PHP
javascript IE中的DOM ready应用技巧
2008/07/23 Javascript
JavaScript设计模式之外观模式实例
2014/10/10 Javascript
jQuery幻灯片特效代码分享--鼠标滑过按钮时切换(2)
2020/11/18 Javascript
分享12个非常实用的JavaScript小技巧
2016/05/11 Javascript
浅谈关于axios和session的一些事
2017/07/13 Javascript
详解关于react-redux中的connect用法介绍及原理解析
2017/09/11 Javascript
json2.js 入门教程之使用方法与实例分析
2017/09/14 Javascript
vue弹窗消息组件的使用方法
2020/09/24 Javascript
Vue头像处理方案小结
2018/07/26 Javascript
JS面试题大坑之隐式类型转换实例代码
2018/10/14 Javascript
vue中el-upload上传图片到七牛的示例代码
2018/10/19 Javascript
jquery自定义组件实例详解
2020/12/31 jQuery
python导出chrome书签到markdown文件的实例代码
2017/12/27 Python
python读写csv文件方法详细总结
2019/07/05 Python
python 进程 进程池 进程间通信实现解析
2019/08/23 Python
基于python使用tibco ems代码实例
2019/12/20 Python
pytorch动态网络以及权重共享实例
2020/01/06 Python
Windows下Sqlmap环境安装教程详解
2020/08/04 Python
goodhealth官方海外旗舰店:新西兰国民营养师
2017/12/15 全球购物
西班牙在线宠物食品和配件商店:bitiba
2019/10/11 全球购物
艺术系应届生的自我评价
2013/10/19 职场文书
会计电算化大学生职业规划书
2014/02/05 职场文书
幼儿园教师节演讲稿
2014/09/03 职场文书
护士工作失误检讨书
2014/09/14 职场文书
乡镇党的群众路线教育实践活动剖析材料
2014/10/09 职场文书
优秀党员先进材料
2014/12/18 职场文书
三八节活动简报
2015/07/20 职场文书
话题作文之自信作文
2019/11/15 职场文书
python使用matplotlib绘制图片时x轴的刻度处理
2021/08/30 Python