Tensorflow中k.gradients()和tf.stop_gradient()用法说明


Posted in Python onJune 10, 2020

上周在实验室开荒某个代码,看到中间这么一段,对Tensorflow中的stop_gradient()还不熟悉,特此周末进行重新并总结。

y = xx + K.stop_gradient(rounded - xx)

这代码最终调用位置在tensoflow.python.ops.gen_array_ops.stop_gradient(input, name=None),关于这段代码为什么这样写的意义在文末给出。

【stop_gradient()意义】

用stop_gradient生成损失函数w.r.t.的梯度。

【tf.gradients()理解】

tf中我们只需要设计我们自己的函数,tf提供提供强大的自动计算函数梯度方法,tf.gradients()。

tf.gradients(
 ys,
 xs,
 grad_ys=None,
 name='gradients',
 colocate_gradients_with_ops=False,
 gate_gradients=False,
 aggregation_method=None,
 stop_gradients=None,
 unconnected_gradients=tf.UnconnectedGradients.NONE
)

gradients() adds ops to the graph to output the derivatives of ys with respect to xs. It returns a list of Tensor of length len(xs) where each tensor is the sum(dy/dx) for y in ys.

1、tf.gradients()实现ys对xs的求导

2、ys和xs可以是Tensor或者list包含的Tensor

3、求导返回值是一个list,list的长度等于len(xs)

eg.假设返回值是[grad1, grad2, grad3],ys=[y1, y2],xs=[x1, x2, x3]。则计算过程为:

Tensorflow中k.gradients()和tf.stop_gradient()用法说明

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
gradients_node = tf.gradients(loss_op, w)
 
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([train_op, gradients_node, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

自定义梯度和更新函数

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
# train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
 
#自定义权重更新
grad_w, grad_b = tf.gradients(loss_op, [w, b])
new_w = w.assign(w - 0.01 * grad_w)
new_b = b.assign(b - 0.01 * grad_b)
 
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([new_w, new_b, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

【tf.stop_gradient()理解】

在tf.gradients()参数中存在stop_gradients,这是一个List,list中的元素是tensorflow graph中的op,一旦进入这个list,将不会被计算梯度,更重要的是,在该op之后的BP计算都不会运行。

import numpy as np
import tensorflow as tf
 
a = tf.constant(0.)
b = 2 * a
c = a + b
g = tf.gradients(c, [a, b])
 
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(g))
 
#输出[3.0, 1.0]

在用一个stop_gradient()的例子

import tensorflow as tf
 
#实验一
w1 = tf.Variable(2.0)
w2 = tf.Variable(2.0)
a = tf.multiply(w1, 3.0)
a_stoped = tf.stop_gradient(a)
 
# b=w1*3.0*w2
b = tf.multiply(a_stoped, w2)
gradients = tf.gradients(b, xs=[w1, w2])
print(gradients)
#输出[None, <tf.Tensor 'gradients/Mul_1_grad/Reshape_1:0' shape=() dtype=float32>]
 
#实验二
a = tf.Variable(1.0)
b = tf.Variable(1.0)
c = tf.add(a, b)
c_stoped = tf.stop_gradient(c)
d = tf.add(a, b)
e = tf.add(c_stoped, d)
gradients = tf.gradients(e, xs=[a, b])
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(gradients))
 
#因为梯度从另外地方传回,所以输出 [1.0, 1.0]

【答案】

开始提出的问题,为什么存在那段代码:

t = g(x)

y = t + tf.stop_gradient(f(x) - t)

这里,我们本来的前向传递函数是XX,但是想要在反向时传递的函数是g(x),因为在前向过程中,tf.stop_gradient()不起作用,因此+t和-t抵消掉了,只剩下f(x)前向传递;而在反向过程中,因为tf.stop_gradient()的作用,使得f(x)-t的梯度变为了0,从而只剩下g(x)在反向传递。

以上这篇Tensorflow中k.gradients()和tf.stop_gradient()用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用简单工厂模式来进行Python的设计模式编程
Mar 01 Python
python之django母板页面的使用
Jul 03 Python
Python读写文件基础知识点
Jun 10 Python
python opencv 批量改变图片的尺寸大小的方法
Jun 28 Python
python区块及区块链的开发详解
Jul 03 Python
Python 进程操作之进程间通过队列共享数据,队列Queue简单示例
Oct 11 Python
Python 中使用 PyMySQL模块操作数据库的方法
Nov 10 Python
python使用PIL剪切和拼接图片
Mar 23 Python
python3.4中清屏的处理方法
Jul 06 Python
用python进行视频剪辑
Nov 02 Python
Pycharm同步远程服务器调试的方法步骤
Nov 04 Python
python爬虫请求头的使用
Dec 01 Python
PySide2出现“ImportError: DLL load failed: 找不到指定的模块”的问题及解决方法
Jun 10 #Python
浅谈Python中的字符串
Jun 10 #Python
Keras 使用 Lambda层详解
Jun 10 #Python
keras打印loss对权重的导数方式
Jun 10 #Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
You might like
浅析SVN常见问题及解决方法
2013/06/21 PHP
一个显示效果非常不错的PHP错误、异常处理类
2014/03/21 PHP
TP5框架model常见操作示例小结【增删改查、聚合、时间戳、软删除等】
2020/04/05 PHP
图片自动缩小的js代码,用以防止图片撑破页面
2007/03/12 Javascript
jquery 无限级联菜单案例分享
2013/03/26 Javascript
前端轻量级MVC框架CanJS详解
2014/09/26 Javascript
使用JavaScript实现连续滚动字幕效果的方法
2015/07/07 Javascript
jQuery遍历json的方法(推荐)
2016/06/12 Javascript
jquery 判断div show的状态实例
2016/12/03 Javascript
Angular指令封装jQuery日期时间插件datetimepicker实现双向绑定示例
2017/01/22 Javascript
JavaScript mixin实现多继承的方法详解
2017/03/30 Javascript
select自定义小三角样式代码(实用总结)
2017/08/18 Javascript
JavaScript实现各种排序的代码详解
2017/08/28 Javascript
ES6下子组件调用父组件的方法(推荐)
2018/02/23 Javascript
详解关于微信setData回调函数中的坑
2019/02/18 Javascript
浅谈vuex中store的命名空间
2019/11/08 Javascript
layui数据表格重载实现往后台传参
2019/11/15 Javascript
利用python GDAL库读写geotiff格式的遥感影像方法
2018/11/29 Python
python实现高斯(Gauss)迭代法的例子
2019/11/20 Python
TensorFlow实现指数衰减学习率的方法
2020/02/05 Python
python统计字符的个数代码实例
2020/02/07 Python
Python字符串函数strip()原理及用法详解
2020/07/23 Python
python中函数返回多个结果的实例方法
2020/12/16 Python
css3动画过渡实现鼠标跟随导航效果
2018/02/08 HTML / CSS
详解Html5 Canvas画线有毛边解决方法
2018/03/01 HTML / CSS
通过Canvas及File API缩放并上传图片完整示例
2013/08/08 HTML / CSS
Vinatis德国:法国领先的葡萄酒邮购公司
2020/09/07 全球购物
中级会计职业生涯规划范文
2014/01/16 职场文书
宿舍保安职务说明书
2014/02/25 职场文书
校园安全演讲稿
2014/05/09 职场文书
英文求职信范文
2014/05/23 职场文书
事业单位财务人员岗位职责
2015/04/14 职场文书
初中班级口号霸气押韵
2015/12/24 职场文书
分位数回归模型quantile regeression应用详解及示例教程
2021/11/02 Python
Java9新特性之Module模块化编程示例演绎
2022/03/16 Java/Android
Springboot集成kafka高级应用实战分享
2022/08/14 Java/Android