Tensorflow中k.gradients()和tf.stop_gradient()用法说明


Posted in Python onJune 10, 2020

上周在实验室开荒某个代码,看到中间这么一段,对Tensorflow中的stop_gradient()还不熟悉,特此周末进行重新并总结。

y = xx + K.stop_gradient(rounded - xx)

这代码最终调用位置在tensoflow.python.ops.gen_array_ops.stop_gradient(input, name=None),关于这段代码为什么这样写的意义在文末给出。

【stop_gradient()意义】

用stop_gradient生成损失函数w.r.t.的梯度。

【tf.gradients()理解】

tf中我们只需要设计我们自己的函数,tf提供提供强大的自动计算函数梯度方法,tf.gradients()。

tf.gradients(
 ys,
 xs,
 grad_ys=None,
 name='gradients',
 colocate_gradients_with_ops=False,
 gate_gradients=False,
 aggregation_method=None,
 stop_gradients=None,
 unconnected_gradients=tf.UnconnectedGradients.NONE
)

gradients() adds ops to the graph to output the derivatives of ys with respect to xs. It returns a list of Tensor of length len(xs) where each tensor is the sum(dy/dx) for y in ys.

1、tf.gradients()实现ys对xs的求导

2、ys和xs可以是Tensor或者list包含的Tensor

3、求导返回值是一个list,list的长度等于len(xs)

eg.假设返回值是[grad1, grad2, grad3],ys=[y1, y2],xs=[x1, x2, x3]。则计算过程为:

Tensorflow中k.gradients()和tf.stop_gradient()用法说明

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
gradients_node = tf.gradients(loss_op, w)
 
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([train_op, gradients_node, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

自定义梯度和更新函数

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
# train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
 
#自定义权重更新
grad_w, grad_b = tf.gradients(loss_op, [w, b])
new_w = w.assign(w - 0.01 * grad_w)
new_b = b.assign(b - 0.01 * grad_b)
 
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([new_w, new_b, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

【tf.stop_gradient()理解】

在tf.gradients()参数中存在stop_gradients,这是一个List,list中的元素是tensorflow graph中的op,一旦进入这个list,将不会被计算梯度,更重要的是,在该op之后的BP计算都不会运行。

import numpy as np
import tensorflow as tf
 
a = tf.constant(0.)
b = 2 * a
c = a + b
g = tf.gradients(c, [a, b])
 
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(g))
 
#输出[3.0, 1.0]

在用一个stop_gradient()的例子

import tensorflow as tf
 
#实验一
w1 = tf.Variable(2.0)
w2 = tf.Variable(2.0)
a = tf.multiply(w1, 3.0)
a_stoped = tf.stop_gradient(a)
 
# b=w1*3.0*w2
b = tf.multiply(a_stoped, w2)
gradients = tf.gradients(b, xs=[w1, w2])
print(gradients)
#输出[None, <tf.Tensor 'gradients/Mul_1_grad/Reshape_1:0' shape=() dtype=float32>]
 
#实验二
a = tf.Variable(1.0)
b = tf.Variable(1.0)
c = tf.add(a, b)
c_stoped = tf.stop_gradient(c)
d = tf.add(a, b)
e = tf.add(c_stoped, d)
gradients = tf.gradients(e, xs=[a, b])
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(gradients))
 
#因为梯度从另外地方传回,所以输出 [1.0, 1.0]

【答案】

开始提出的问题,为什么存在那段代码:

t = g(x)

y = t + tf.stop_gradient(f(x) - t)

这里,我们本来的前向传递函数是XX,但是想要在反向时传递的函数是g(x),因为在前向过程中,tf.stop_gradient()不起作用,因此+t和-t抵消掉了,只剩下f(x)前向传递;而在反向过程中,因为tf.stop_gradient()的作用,使得f(x)-t的梯度变为了0,从而只剩下g(x)在反向传递。

以上这篇Tensorflow中k.gradients()和tf.stop_gradient()用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入浅析python with语句简介
Apr 11 Python
python实现pdf转换成word/txt纯文本文件
Jun 07 Python
Python装饰器简单用法实例小结
Dec 03 Python
搞清楚 Python traceback的具体使用方法
May 13 Python
Python秒算24点实现及原理详解
Jul 29 Python
PyTorch和Keras计算模型参数的例子
Jan 02 Python
PyTorch中的Variable变量详解
Jan 07 Python
python2 对excel表格操作完整示例
Feb 23 Python
Python使用Socket实现简单聊天程序
Feb 28 Python
Python子进程subpocess原理及用法解析
Jul 16 Python
python简单实现插入排序实例代码
Dec 16 Python
利用python为PostgreSQL的表自动添加分区
Jan 18 Python
PySide2出现“ImportError: DLL load failed: 找不到指定的模块”的问题及解决方法
Jun 10 #Python
浅谈Python中的字符串
Jun 10 #Python
Keras 使用 Lambda层详解
Jun 10 #Python
keras打印loss对权重的导数方式
Jun 10 #Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
You might like
PHP合并数组+与array_merge的区别分析
2010/08/01 PHP
php设计模式 Prototype (原型模式)代码
2011/06/26 PHP
ThinkPHP模板Switch标签用法示例
2014/06/30 PHP
php中文字符串截取方法实例总结
2014/09/30 PHP
PHP基于工厂模式实现的计算器实例
2015/07/16 PHP
PHP7之Mongodb API使用详解
2015/12/26 PHP
PHP中字符串长度的截取用法示例
2017/01/12 PHP
浅谈PHP中的面向对象OOP中的魔术方法
2017/06/12 PHP
javascript操作cookie_获取与修改代码
2009/05/21 Javascript
Javascript中封装window.open解决不兼容问题
2014/09/28 Javascript
浅谈JavaScript中的String对象常用方法
2015/02/25 Javascript
举例简介AngularJS的内部语言环境
2015/06/17 Javascript
HTML Table 空白单元格补全的简单实现
2016/10/13 Javascript
Bootstrap Table使用心得总结
2016/11/29 Javascript
Bootstrap提示框效果的实例代码
2017/07/12 Javascript
让nodeJS支持ES6的词法----babel的安装和使用方法
2017/07/31 NodeJs
JavaScript中错误正确处理方式小结你用对了吗
2017/10/10 Javascript
仿京东快报向上滚动的实例
2017/12/13 Javascript
JS实现键值对遍历json数组功能示例
2018/05/30 Javascript
python实现随机密码字典生成器示例
2014/04/09 Python
使用IPython下的Net-SNMP来管理类UNIX系统的教程
2015/04/15 Python
Python的GUI框架PySide的安装配置教程
2016/02/16 Python
Python基于回溯法解决01背包问题实例
2017/12/06 Python
Python之list对应元素求和的方法
2018/06/28 Python
Python实现曲线拟合操作示例【基于numpy,scipy,matplotlib库】
2018/07/12 Python
python算法与数据结构之冒泡排序实例详解
2019/06/22 Python
python+selenium select下拉选择框定位处理方法
2019/08/24 Python
Python 使用 Pillow 模块给图片添加文字水印的方法
2019/08/30 Python
pyqt5 QlistView列表显示的实现示例
2020/03/24 Python
大学生求职推荐信
2013/11/27 职场文书
英语自荐信常用语句
2013/12/13 职场文书
音乐幼师求职信
2014/07/09 职场文书
农业生产宣传标语
2014/10/08 职场文书
创业计划书之婴幼儿游泳馆
2019/09/11 职场文书
html2 canvas svg不能识别的解决方案
2021/06/03 HTML / CSS
如何理解python接口自动化之logging日志模块
2021/06/15 Python