Tensorflow中k.gradients()和tf.stop_gradient()用法说明


Posted in Python onJune 10, 2020

上周在实验室开荒某个代码,看到中间这么一段,对Tensorflow中的stop_gradient()还不熟悉,特此周末进行重新并总结。

y = xx + K.stop_gradient(rounded - xx)

这代码最终调用位置在tensoflow.python.ops.gen_array_ops.stop_gradient(input, name=None),关于这段代码为什么这样写的意义在文末给出。

【stop_gradient()意义】

用stop_gradient生成损失函数w.r.t.的梯度。

【tf.gradients()理解】

tf中我们只需要设计我们自己的函数,tf提供提供强大的自动计算函数梯度方法,tf.gradients()。

tf.gradients(
 ys,
 xs,
 grad_ys=None,
 name='gradients',
 colocate_gradients_with_ops=False,
 gate_gradients=False,
 aggregation_method=None,
 stop_gradients=None,
 unconnected_gradients=tf.UnconnectedGradients.NONE
)

gradients() adds ops to the graph to output the derivatives of ys with respect to xs. It returns a list of Tensor of length len(xs) where each tensor is the sum(dy/dx) for y in ys.

1、tf.gradients()实现ys对xs的求导

2、ys和xs可以是Tensor或者list包含的Tensor

3、求导返回值是一个list,list的长度等于len(xs)

eg.假设返回值是[grad1, grad2, grad3],ys=[y1, y2],xs=[x1, x2, x3]。则计算过程为:

Tensorflow中k.gradients()和tf.stop_gradient()用法说明

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
gradients_node = tf.gradients(loss_op, w)
 
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([train_op, gradients_node, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

自定义梯度和更新函数

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
# train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
 
#自定义权重更新
grad_w, grad_b = tf.gradients(loss_op, [w, b])
new_w = w.assign(w - 0.01 * grad_w)
new_b = b.assign(b - 0.01 * grad_b)
 
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([new_w, new_b, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

【tf.stop_gradient()理解】

在tf.gradients()参数中存在stop_gradients,这是一个List,list中的元素是tensorflow graph中的op,一旦进入这个list,将不会被计算梯度,更重要的是,在该op之后的BP计算都不会运行。

import numpy as np
import tensorflow as tf
 
a = tf.constant(0.)
b = 2 * a
c = a + b
g = tf.gradients(c, [a, b])
 
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(g))
 
#输出[3.0, 1.0]

在用一个stop_gradient()的例子

import tensorflow as tf
 
#实验一
w1 = tf.Variable(2.0)
w2 = tf.Variable(2.0)
a = tf.multiply(w1, 3.0)
a_stoped = tf.stop_gradient(a)
 
# b=w1*3.0*w2
b = tf.multiply(a_stoped, w2)
gradients = tf.gradients(b, xs=[w1, w2])
print(gradients)
#输出[None, <tf.Tensor 'gradients/Mul_1_grad/Reshape_1:0' shape=() dtype=float32>]
 
#实验二
a = tf.Variable(1.0)
b = tf.Variable(1.0)
c = tf.add(a, b)
c_stoped = tf.stop_gradient(c)
d = tf.add(a, b)
e = tf.add(c_stoped, d)
gradients = tf.gradients(e, xs=[a, b])
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(gradients))
 
#因为梯度从另外地方传回,所以输出 [1.0, 1.0]

【答案】

开始提出的问题,为什么存在那段代码:

t = g(x)

y = t + tf.stop_gradient(f(x) - t)

这里,我们本来的前向传递函数是XX,但是想要在反向时传递的函数是g(x),因为在前向过程中,tf.stop_gradient()不起作用,因此+t和-t抵消掉了,只剩下f(x)前向传递;而在反向过程中,因为tf.stop_gradient()的作用,使得f(x)-t的梯度变为了0,从而只剩下g(x)在反向传递。

以上这篇Tensorflow中k.gradients()和tf.stop_gradient()用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
举例讲解Python的Tornado框架实现数据可视化的教程
May 02 Python
栈和队列数据结构的基本概念及其相关的Python实现
Aug 24 Python
Python用模块pytz来转换时区
Aug 19 Python
浅谈django开发者模式中的autoreload是如何实现的
Aug 18 Python
Python实现基于二叉树存储结构的堆排序算法示例
Dec 08 Python
python3+PyQt5泛型委托详解
Apr 24 Python
详解Django解决ajax跨域访问问题
Aug 24 Python
python使用xlrd和xlwt读写Excel文件的实例代码
Sep 05 Python
windows 10 设定计划任务自动执行 python 脚本的方法
Sep 11 Python
pytest中文文档之编写断言
Sep 12 Python
Python3读写Excel文件(使用xlrd,xlsxwriter,openpyxl3种方式读写实例与优劣)
Feb 13 Python
M1芯片安装python3.9.1的实现
Feb 02 Python
PySide2出现“ImportError: DLL load failed: 找不到指定的模块”的问题及解决方法
Jun 10 #Python
浅谈Python中的字符串
Jun 10 #Python
Keras 使用 Lambda层详解
Jun 10 #Python
keras打印loss对权重的导数方式
Jun 10 #Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
You might like
ecshop 订单确认中显示省市地址信息的方法
2010/03/15 PHP
php版微信支付api.mch.weixin.qq.com域名解析慢原因与解决方法
2016/10/12 PHP
PHP生成随机密码4种方法及性能对比
2020/12/11 PHP
Area 区域实现post提交数据的js写法
2014/04/22 Javascript
javascript实现图片自动和可控的轮播切换特效
2015/04/13 Javascript
jQuery+AJAX实现遮罩层登录验证界面(附源码)
2020/09/13 Javascript
JS中Json数据的处理和解析JSON数据的方法详解
2016/06/29 Javascript
浅谈js内置对象Math的属性和方法(推荐)
2016/09/19 Javascript
javascript对浅拷贝和深拷贝的详解
2016/10/14 Javascript
AngularJs表单验证实例代码解析
2016/11/29 Javascript
js中常用的Math方法总结
2017/01/12 Javascript
微信小程序-小说阅读小程序实例(demo)
2017/01/12 Javascript
ES6新特性之字符串的扩展实例分析
2017/04/01 Javascript
详解Vue2.0里过滤器容易踩到的坑
2017/06/01 Javascript
JavaScript中错误正确处理方式小结你用对了吗
2017/10/10 Javascript
Vue官方文档梳理之全局配置
2017/11/22 Javascript
vue click.stop阻止点击事件继续传播的方法
2018/09/04 Javascript
vue-cli脚手架build目录下utils.js工具配置文件详解
2018/09/14 Javascript
使用ThinkJs搭建微信中控服务的实现方法
2019/08/08 Javascript
JS+HTML5本地存储Localstorage实现注册登录及验证功能示例
2020/02/10 Javascript
[02:09]DOTA2辉夜杯 EHOME夺冠举杯现场
2015/12/28 DOTA
使用Python对MySQL数据操作
2017/04/06 Python
Python对象类型及其运算方法(详解)
2017/07/05 Python
python输入错误密码用户锁定实现方法
2017/11/27 Python
python并发编程多进程 互斥锁原理解析
2019/08/20 Python
会计学财务管理专业个人的自我评价
2013/10/19 职场文书
爱国卫生月实施方案
2014/02/21 职场文书
公司节能减排倡议书
2014/05/14 职场文书
报效祖国演讲稿
2014/09/15 职场文书
玩手机检讨书1000字
2014/10/20 职场文书
党的群众路线教育实践活动心得体会(医院)
2014/11/03 职场文书
教师学期末个人总结
2015/02/13 职场文书
公司宣传语大全
2015/07/13 职场文书
Redis集群新增、删除节点以及动态增加内存的方法
2021/09/04 Redis
《艾尔登法环》1.03.3补丁上线 碎星伤害调整
2022/04/06 其他游戏
Nginx+Tomcat负载均衡多实例详解
2022/04/11 Servers