Tensorflow中k.gradients()和tf.stop_gradient()用法说明


Posted in Python onJune 10, 2020

上周在实验室开荒某个代码,看到中间这么一段,对Tensorflow中的stop_gradient()还不熟悉,特此周末进行重新并总结。

y = xx + K.stop_gradient(rounded - xx)

这代码最终调用位置在tensoflow.python.ops.gen_array_ops.stop_gradient(input, name=None),关于这段代码为什么这样写的意义在文末给出。

【stop_gradient()意义】

用stop_gradient生成损失函数w.r.t.的梯度。

【tf.gradients()理解】

tf中我们只需要设计我们自己的函数,tf提供提供强大的自动计算函数梯度方法,tf.gradients()。

tf.gradients(
 ys,
 xs,
 grad_ys=None,
 name='gradients',
 colocate_gradients_with_ops=False,
 gate_gradients=False,
 aggregation_method=None,
 stop_gradients=None,
 unconnected_gradients=tf.UnconnectedGradients.NONE
)

gradients() adds ops to the graph to output the derivatives of ys with respect to xs. It returns a list of Tensor of length len(xs) where each tensor is the sum(dy/dx) for y in ys.

1、tf.gradients()实现ys对xs的求导

2、ys和xs可以是Tensor或者list包含的Tensor

3、求导返回值是一个list,list的长度等于len(xs)

eg.假设返回值是[grad1, grad2, grad3],ys=[y1, y2],xs=[x1, x2, x3]。则计算过程为:

Tensorflow中k.gradients()和tf.stop_gradient()用法说明

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
gradients_node = tf.gradients(loss_op, w)
 
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([train_op, gradients_node, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

自定义梯度和更新函数

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
# train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
 
#自定义权重更新
grad_w, grad_b = tf.gradients(loss_op, [w, b])
new_w = w.assign(w - 0.01 * grad_w)
new_b = b.assign(b - 0.01 * grad_b)
 
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([new_w, new_b, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

【tf.stop_gradient()理解】

在tf.gradients()参数中存在stop_gradients,这是一个List,list中的元素是tensorflow graph中的op,一旦进入这个list,将不会被计算梯度,更重要的是,在该op之后的BP计算都不会运行。

import numpy as np
import tensorflow as tf
 
a = tf.constant(0.)
b = 2 * a
c = a + b
g = tf.gradients(c, [a, b])
 
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(g))
 
#输出[3.0, 1.0]

在用一个stop_gradient()的例子

import tensorflow as tf
 
#实验一
w1 = tf.Variable(2.0)
w2 = tf.Variable(2.0)
a = tf.multiply(w1, 3.0)
a_stoped = tf.stop_gradient(a)
 
# b=w1*3.0*w2
b = tf.multiply(a_stoped, w2)
gradients = tf.gradients(b, xs=[w1, w2])
print(gradients)
#输出[None, <tf.Tensor 'gradients/Mul_1_grad/Reshape_1:0' shape=() dtype=float32>]
 
#实验二
a = tf.Variable(1.0)
b = tf.Variable(1.0)
c = tf.add(a, b)
c_stoped = tf.stop_gradient(c)
d = tf.add(a, b)
e = tf.add(c_stoped, d)
gradients = tf.gradients(e, xs=[a, b])
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(gradients))
 
#因为梯度从另外地方传回,所以输出 [1.0, 1.0]

【答案】

开始提出的问题,为什么存在那段代码:

t = g(x)

y = t + tf.stop_gradient(f(x) - t)

这里,我们本来的前向传递函数是XX,但是想要在反向时传递的函数是g(x),因为在前向过程中,tf.stop_gradient()不起作用,因此+t和-t抵消掉了,只剩下f(x)前向传递;而在反向过程中,因为tf.stop_gradient()的作用,使得f(x)-t的梯度变为了0,从而只剩下g(x)在反向传递。

以上这篇Tensorflow中k.gradients()和tf.stop_gradient()用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现在线音乐播放器
Mar 03 Python
Python实现采用进度条实时显示处理进度的方法
Dec 19 Python
python操作列表的函数使用代码详解
Dec 28 Python
对python多线程中Lock()与RLock()锁详解
Jan 11 Python
Python Opencv实现图像轮廓识别功能
Mar 23 Python
python opencv捕获摄像头并显示内容的实现
Jul 11 Python
Python进阶之迭代器与迭代器切片教程
Jan 29 Python
pandas参数设置的实用小技巧
Aug 23 Python
Python+OpenCV检测灯光亮点的实现方法
Nov 02 Python
Python实现哲学家就餐问题实例代码
Nov 09 Python
Django跨域请求原理及实现代码
Nov 14 Python
Python实现小黑屋游戏的完整实例
Jan 06 Python
PySide2出现“ImportError: DLL load failed: 找不到指定的模块”的问题及解决方法
Jun 10 #Python
浅谈Python中的字符串
Jun 10 #Python
Keras 使用 Lambda层详解
Jun 10 #Python
keras打印loss对权重的导数方式
Jun 10 #Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
You might like
CodeIgniter基本配置详细介绍
2013/11/12 PHP
php使用unset()删除数组中某个单元(键)的方法
2015/02/17 PHP
简单谈谈PHP面向对象之标识对象
2017/06/27 PHP
php文件操作之文件写入字符串、数组的方法分析
2019/04/15 PHP
浅析JavaScript基本类型与引用类型
2014/05/28 Javascript
简单的jquery左侧导航栏和页面选中效果
2014/08/21 Javascript
javascript模拟评分控件实现方法
2015/05/13 Javascript
JS实现按比例缩放图片的方法(附C#版代码)
2015/12/08 Javascript
AngularJS中处理多个promise的方式
2016/02/02 Javascript
javascript+HTML5 Canvas绘制转盘抽奖
2020/05/16 Javascript
JS事件添加和移出的兼容写法示例
2016/06/20 Javascript
微信小程序 在Chrome浏览器上运行以及WebStorm的使用
2016/09/27 Javascript
JS获得一个对象的所有属性和方法实例
2017/02/21 Javascript
Angular4 中内置指令的基本用法
2017/07/31 Javascript
详解node服务器中打开html文件的两种方法
2017/09/18 Javascript
AngularJS实现的select二级联动下拉菜单功能示例
2017/10/25 Javascript
基于nodejs实现微信支付功能
2017/12/20 NodeJs
js点击时关闭该范围下拉菜单之外的菜单方法
2018/01/11 Javascript
js实现小球在页面规定的区域运动
2020/06/16 Javascript
JavaScript实时更新当前的时间的示例代码
2020/07/15 Javascript
ReactRouter的实现方法
2021/01/25 Javascript
[原创]使用豆瓣提供的国内pypi源
2017/07/02 Python
python如何派生内置不可变类型并修改实例化行为
2018/03/21 Python
Python实现扣除个人税后的工资计算器示例
2018/03/26 Python
TensorFlow打印tensor值的实现方法
2018/07/27 Python
Python使用sklearn实现的各种回归算法示例
2019/07/04 Python
Django实现文件上传下载功能
2019/10/06 Python
鼠标滚轮事件和Mac触控板双指事件
2019/12/23 HTML / CSS
美国最便宜的旅游网站:CheapTickets
2017/07/09 全球购物
在Java开发中如何选择使用哪种集合类
2016/08/09 面试题
优秀共产党员先进事迹
2014/01/27 职场文书
刘公岛导游词
2015/02/05 职场文书
学雷锋团日活动总结
2015/05/06 职场文书
刑事附带民事代理词
2015/05/25 职场文书
Python import模块的缓存问题解决方案
2021/06/02 Python
仅仅使用 HTML/CSS 实现各类进度条的方式汇总
2021/11/11 HTML / CSS