Tensorflow中k.gradients()和tf.stop_gradient()用法说明


Posted in Python onJune 10, 2020

上周在实验室开荒某个代码,看到中间这么一段,对Tensorflow中的stop_gradient()还不熟悉,特此周末进行重新并总结。

y = xx + K.stop_gradient(rounded - xx)

这代码最终调用位置在tensoflow.python.ops.gen_array_ops.stop_gradient(input, name=None),关于这段代码为什么这样写的意义在文末给出。

【stop_gradient()意义】

用stop_gradient生成损失函数w.r.t.的梯度。

【tf.gradients()理解】

tf中我们只需要设计我们自己的函数,tf提供提供强大的自动计算函数梯度方法,tf.gradients()。

tf.gradients(
 ys,
 xs,
 grad_ys=None,
 name='gradients',
 colocate_gradients_with_ops=False,
 gate_gradients=False,
 aggregation_method=None,
 stop_gradients=None,
 unconnected_gradients=tf.UnconnectedGradients.NONE
)

gradients() adds ops to the graph to output the derivatives of ys with respect to xs. It returns a list of Tensor of length len(xs) where each tensor is the sum(dy/dx) for y in ys.

1、tf.gradients()实现ys对xs的求导

2、ys和xs可以是Tensor或者list包含的Tensor

3、求导返回值是一个list,list的长度等于len(xs)

eg.假设返回值是[grad1, grad2, grad3],ys=[y1, y2],xs=[x1, x2, x3]。则计算过程为:

Tensorflow中k.gradients()和tf.stop_gradient()用法说明

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
gradients_node = tf.gradients(loss_op, w)
 
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([train_op, gradients_node, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

自定义梯度和更新函数

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
# train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
 
#自定义权重更新
grad_w, grad_b = tf.gradients(loss_op, [w, b])
new_w = w.assign(w - 0.01 * grad_w)
new_b = b.assign(b - 0.01 * grad_b)
 
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([new_w, new_b, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

【tf.stop_gradient()理解】

在tf.gradients()参数中存在stop_gradients,这是一个List,list中的元素是tensorflow graph中的op,一旦进入这个list,将不会被计算梯度,更重要的是,在该op之后的BP计算都不会运行。

import numpy as np
import tensorflow as tf
 
a = tf.constant(0.)
b = 2 * a
c = a + b
g = tf.gradients(c, [a, b])
 
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(g))
 
#输出[3.0, 1.0]

在用一个stop_gradient()的例子

import tensorflow as tf
 
#实验一
w1 = tf.Variable(2.0)
w2 = tf.Variable(2.0)
a = tf.multiply(w1, 3.0)
a_stoped = tf.stop_gradient(a)
 
# b=w1*3.0*w2
b = tf.multiply(a_stoped, w2)
gradients = tf.gradients(b, xs=[w1, w2])
print(gradients)
#输出[None, <tf.Tensor 'gradients/Mul_1_grad/Reshape_1:0' shape=() dtype=float32>]
 
#实验二
a = tf.Variable(1.0)
b = tf.Variable(1.0)
c = tf.add(a, b)
c_stoped = tf.stop_gradient(c)
d = tf.add(a, b)
e = tf.add(c_stoped, d)
gradients = tf.gradients(e, xs=[a, b])
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(gradients))
 
#因为梯度从另外地方传回,所以输出 [1.0, 1.0]

【答案】

开始提出的问题,为什么存在那段代码:

t = g(x)

y = t + tf.stop_gradient(f(x) - t)

这里,我们本来的前向传递函数是XX,但是想要在反向时传递的函数是g(x),因为在前向过程中,tf.stop_gradient()不起作用,因此+t和-t抵消掉了,只剩下f(x)前向传递;而在反向过程中,因为tf.stop_gradient()的作用,使得f(x)-t的梯度变为了0,从而只剩下g(x)在反向传递。

以上这篇Tensorflow中k.gradients()和tf.stop_gradient()用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的迭代器和生成器使用实例
Jan 14 Python
Python实现查找匹配项作处理后再替换回去的方法
Jun 10 Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
Dec 14 Python
Python装饰器知识点补充
May 28 Python
Django数据库连接丢失问题的解决方法
Dec 29 Python
Python socket实现多对多全双工通信的方法
Feb 13 Python
解决python线程卡死的问题
Feb 18 Python
PyQt5 QListWidget选择多项并返回的实例
Jun 17 Python
python GUI库图形界面开发之PyQt5多行文本框控件QTextEdit详细使用方法实例
Feb 28 Python
python 弧度与角度互转实例
Apr 15 Python
win10下python3.8的PIL库安装过程
Jun 08 Python
matplotlib部件之矩形选区(RectangleSelector)的实现
Feb 01 Python
PySide2出现“ImportError: DLL load failed: 找不到指定的模块”的问题及解决方法
Jun 10 #Python
浅谈Python中的字符串
Jun 10 #Python
Keras 使用 Lambda层详解
Jun 10 #Python
keras打印loss对权重的导数方式
Jun 10 #Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
You might like
PHP的开合式多级菜单程序
2006/10/09 PHP
用PHPdig打造属于你自己的Google[图文教程]
2007/02/14 PHP
php中json_decode()和json_encode()的使用方法
2012/06/04 PHP
laravel利用中间件防止未登录用户直接访问后台的方法
2019/09/30 PHP
如何在PHP中使用数组
2020/06/09 PHP
一个JavaScript获取元素当前高度的实例
2014/10/29 Javascript
jQuery点缩略图弹出层显示大图片
2015/02/13 Javascript
使用jQuery在对象中缓存选择器的简单方法
2015/06/30 Javascript
javascript每日必学之运算符
2016/02/16 Javascript
Json解析的方法小结
2016/06/22 Javascript
JavaScript箭头(arrow)函数详解
2017/06/04 Javascript
详解webpack+gulp实现自动构建部署
2017/06/29 Javascript
JS中this的指向以及call、apply的作用
2018/05/06 Javascript
nodejs对项目下所有空文件夹创建gitkeep的方法
2019/08/02 NodeJs
菜鸟使用python实现正则检测密码合法性
2016/01/05 Python
Python基础教程之浅拷贝和深拷贝实例详解
2017/07/15 Python
python发送邮件实例分享
2017/07/28 Python
如何在sae中设置django,让sae的工作环境跟本地python环境一致
2017/11/21 Python
python 地图经纬度转换、纠偏的实例代码
2018/08/06 Python
python随机在一张图像上截取任意大小图片的方法
2019/01/24 Python
python set内置函数的具体使用
2019/07/02 Python
Django stark组件使用及原理详解
2019/08/22 Python
利用python实现逐步回归
2020/02/24 Python
浅谈django框架集成swagger以及自定义参数问题
2020/07/07 Python
Python matplotlib图例放在外侧保存时显示不完整问题解决
2020/07/28 Python
详解HTML5中rel属性的prefetch预加载功能使用
2016/05/06 HTML / CSS
制衣厂各岗位职责
2013/12/02 职场文书
外语专业毕业生自荐信
2014/04/14 职场文书
机关单位工作失职检讨书
2014/11/20 职场文书
2014年初级职称工作总结
2014/12/08 职场文书
领导欢迎词范文
2015/01/26 职场文书
小学重阳节活动总结
2015/03/24 职场文书
毕业设计工作总结
2015/08/14 职场文书
中秋节英文祝福语句(14句)
2019/09/11 职场文书
Mysql 如何查询时间段交集
2021/06/08 MySQL
试了下Golang实现try catch的方法
2021/07/01 Golang