Tensorflow中k.gradients()和tf.stop_gradient()用法说明


Posted in Python onJune 10, 2020

上周在实验室开荒某个代码,看到中间这么一段,对Tensorflow中的stop_gradient()还不熟悉,特此周末进行重新并总结。

y = xx + K.stop_gradient(rounded - xx)

这代码最终调用位置在tensoflow.python.ops.gen_array_ops.stop_gradient(input, name=None),关于这段代码为什么这样写的意义在文末给出。

【stop_gradient()意义】

用stop_gradient生成损失函数w.r.t.的梯度。

【tf.gradients()理解】

tf中我们只需要设计我们自己的函数,tf提供提供强大的自动计算函数梯度方法,tf.gradients()。

tf.gradients(
 ys,
 xs,
 grad_ys=None,
 name='gradients',
 colocate_gradients_with_ops=False,
 gate_gradients=False,
 aggregation_method=None,
 stop_gradients=None,
 unconnected_gradients=tf.UnconnectedGradients.NONE
)

gradients() adds ops to the graph to output the derivatives of ys with respect to xs. It returns a list of Tensor of length len(xs) where each tensor is the sum(dy/dx) for y in ys.

1、tf.gradients()实现ys对xs的求导

2、ys和xs可以是Tensor或者list包含的Tensor

3、求导返回值是一个list,list的长度等于len(xs)

eg.假设返回值是[grad1, grad2, grad3],ys=[y1, y2],xs=[x1, x2, x3]。则计算过程为:

Tensorflow中k.gradients()和tf.stop_gradient()用法说明

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
gradients_node = tf.gradients(loss_op, w)
 
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([train_op, gradients_node, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

自定义梯度和更新函数

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
# train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
 
#自定义权重更新
grad_w, grad_b = tf.gradients(loss_op, [w, b])
new_w = w.assign(w - 0.01 * grad_w)
new_b = b.assign(b - 0.01 * grad_b)
 
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([new_w, new_b, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

【tf.stop_gradient()理解】

在tf.gradients()参数中存在stop_gradients,这是一个List,list中的元素是tensorflow graph中的op,一旦进入这个list,将不会被计算梯度,更重要的是,在该op之后的BP计算都不会运行。

import numpy as np
import tensorflow as tf
 
a = tf.constant(0.)
b = 2 * a
c = a + b
g = tf.gradients(c, [a, b])
 
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(g))
 
#输出[3.0, 1.0]

在用一个stop_gradient()的例子

import tensorflow as tf
 
#实验一
w1 = tf.Variable(2.0)
w2 = tf.Variable(2.0)
a = tf.multiply(w1, 3.0)
a_stoped = tf.stop_gradient(a)
 
# b=w1*3.0*w2
b = tf.multiply(a_stoped, w2)
gradients = tf.gradients(b, xs=[w1, w2])
print(gradients)
#输出[None, <tf.Tensor 'gradients/Mul_1_grad/Reshape_1:0' shape=() dtype=float32>]
 
#实验二
a = tf.Variable(1.0)
b = tf.Variable(1.0)
c = tf.add(a, b)
c_stoped = tf.stop_gradient(c)
d = tf.add(a, b)
e = tf.add(c_stoped, d)
gradients = tf.gradients(e, xs=[a, b])
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(gradients))
 
#因为梯度从另外地方传回,所以输出 [1.0, 1.0]

【答案】

开始提出的问题,为什么存在那段代码:

t = g(x)

y = t + tf.stop_gradient(f(x) - t)

这里,我们本来的前向传递函数是XX,但是想要在反向时传递的函数是g(x),因为在前向过程中,tf.stop_gradient()不起作用,因此+t和-t抵消掉了,只剩下f(x)前向传递;而在反向过程中,因为tf.stop_gradient()的作用,使得f(x)-t的梯度变为了0,从而只剩下g(x)在反向传递。

以上这篇Tensorflow中k.gradients()和tf.stop_gradient()用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的装饰器用法详解
Jan 14 Python
深入讲解Python中面向对象编程的相关知识
May 25 Python
轻松掌握python设计模式之访问者模式
Nov 18 Python
python实现的AES双向对称加密解密与用法分析
May 02 Python
python使用Pycharm创建一个Django项目
Mar 05 Python
python之DataFrame实现excel合并单元格
Feb 22 Python
Python实现动态添加属性和方法操作示例
Jul 25 Python
python对文件目录的操作方法实例总结
Jun 24 Python
python使用Qt界面以及逻辑实现方法
Jul 10 Python
Pycharm 文件更改目录后,执行路径未更新的解决方法
Jul 19 Python
使用python代码进行身份证号校验的实现示例
Nov 21 Python
Python版中国省市经纬度
Feb 11 Python
PySide2出现“ImportError: DLL load failed: 找不到指定的模块”的问题及解决方法
Jun 10 #Python
浅谈Python中的字符串
Jun 10 #Python
Keras 使用 Lambda层详解
Jun 10 #Python
keras打印loss对权重的导数方式
Jun 10 #Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
You might like
用Apache反向代理设置对外的WWW和文件服务器
2006/10/09 PHP
用PHP查询域名状态whois的类
2006/11/25 PHP
php+html5实现无刷新图片上传教程
2016/01/22 PHP
Yii2表单事件之Ajax提交实现方法
2017/05/04 PHP
Laravel中获取路由参数Route Parameters的五种方法示例
2017/09/29 PHP
纯CSS3实现质感细腻丝滑按钮
2021/03/09 HTML / CSS
jquery增加时编辑jqGrid(实例代码)
2013/11/08 Javascript
firefox下jquery ajax返回object XMLDocument处理方法
2014/01/26 Javascript
js生成随机数的过程解析
2015/11/24 Javascript
Js实现京东无延迟菜单效果实例(demo)
2017/06/02 Javascript
基于js 各种排序方法和sort方法的区别(详解)
2018/01/03 Javascript
JS简单获得节点元素的方法示例
2018/02/10 Javascript
JavaScript门道之标准库
2018/05/26 Javascript
nodejs实现一个word文档解析器思路详解
2018/08/14 NodeJs
[03:17]DOTA2英雄基础教程 剧毒术士
2013/12/12 DOTA
[05:09]第二届DOTA2亚洲邀请赛决赛日比赛集锦:iG 3:0 OG夺冠
2017/04/05 DOTA
Python实现扫描局域网活动ip(扫描在线电脑)
2015/04/28 Python
详解Python的Django框架中的中间件
2015/07/24 Python
python使用arcpy.mapping模块批量出图
2017/03/06 Python
老生常谈python之鸭子类和多态
2017/06/13 Python
Python基于递归实现电话号码映射功能示例
2018/04/13 Python
python检索特定内容的文本文件实例
2018/06/05 Python
python使用多线程编写tcp客户端程序
2019/09/02 Python
pygame实现成语填空游戏
2019/10/29 Python
Python 静态方法和类方法实例分析
2019/11/21 Python
python3 xpath和requests应用详解
2020/03/06 Python
keras处理欠拟合和过拟合的实例讲解
2020/05/25 Python
Python-split()函数实例用法讲解
2020/12/18 Python
css3实现一个div设置多张背景图片及background-image属性实例演示
2017/08/10 HTML / CSS
html2canvas生成的图片偏移不完整的解决方法
2020/05/19 HTML / CSS
Lands’ End英国官方网站:高质量男女服装
2017/10/07 全球购物
神话般的珠宝:Ross-Simons
2020/07/13 全球购物
自我评价优秀范文分享
2013/11/30 职场文书
甜品店创业计划书
2014/09/21 职场文书
2015年初中教务处工作总结
2015/07/21 职场文书
Jackson 反序列化时实现大小写不敏感设置
2021/06/29 Java/Android