Tensorflow中k.gradients()和tf.stop_gradient()用法说明


Posted in Python onJune 10, 2020

上周在实验室开荒某个代码,看到中间这么一段,对Tensorflow中的stop_gradient()还不熟悉,特此周末进行重新并总结。

y = xx + K.stop_gradient(rounded - xx)

这代码最终调用位置在tensoflow.python.ops.gen_array_ops.stop_gradient(input, name=None),关于这段代码为什么这样写的意义在文末给出。

【stop_gradient()意义】

用stop_gradient生成损失函数w.r.t.的梯度。

【tf.gradients()理解】

tf中我们只需要设计我们自己的函数,tf提供提供强大的自动计算函数梯度方法,tf.gradients()。

tf.gradients(
 ys,
 xs,
 grad_ys=None,
 name='gradients',
 colocate_gradients_with_ops=False,
 gate_gradients=False,
 aggregation_method=None,
 stop_gradients=None,
 unconnected_gradients=tf.UnconnectedGradients.NONE
)

gradients() adds ops to the graph to output the derivatives of ys with respect to xs. It returns a list of Tensor of length len(xs) where each tensor is the sum(dy/dx) for y in ys.

1、tf.gradients()实现ys对xs的求导

2、ys和xs可以是Tensor或者list包含的Tensor

3、求导返回值是一个list,list的长度等于len(xs)

eg.假设返回值是[grad1, grad2, grad3],ys=[y1, y2],xs=[x1, x2, x3]。则计算过程为:

Tensorflow中k.gradients()和tf.stop_gradient()用法说明

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
gradients_node = tf.gradients(loss_op, w)
 
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([train_op, gradients_node, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

自定义梯度和更新函数

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
# train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
 
#自定义权重更新
grad_w, grad_b = tf.gradients(loss_op, [w, b])
new_w = w.assign(w - 0.01 * grad_w)
new_b = b.assign(b - 0.01 * grad_b)
 
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([new_w, new_b, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

【tf.stop_gradient()理解】

在tf.gradients()参数中存在stop_gradients,这是一个List,list中的元素是tensorflow graph中的op,一旦进入这个list,将不会被计算梯度,更重要的是,在该op之后的BP计算都不会运行。

import numpy as np
import tensorflow as tf
 
a = tf.constant(0.)
b = 2 * a
c = a + b
g = tf.gradients(c, [a, b])
 
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(g))
 
#输出[3.0, 1.0]

在用一个stop_gradient()的例子

import tensorflow as tf
 
#实验一
w1 = tf.Variable(2.0)
w2 = tf.Variable(2.0)
a = tf.multiply(w1, 3.0)
a_stoped = tf.stop_gradient(a)
 
# b=w1*3.0*w2
b = tf.multiply(a_stoped, w2)
gradients = tf.gradients(b, xs=[w1, w2])
print(gradients)
#输出[None, <tf.Tensor 'gradients/Mul_1_grad/Reshape_1:0' shape=() dtype=float32>]
 
#实验二
a = tf.Variable(1.0)
b = tf.Variable(1.0)
c = tf.add(a, b)
c_stoped = tf.stop_gradient(c)
d = tf.add(a, b)
e = tf.add(c_stoped, d)
gradients = tf.gradients(e, xs=[a, b])
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(gradients))
 
#因为梯度从另外地方传回,所以输出 [1.0, 1.0]

【答案】

开始提出的问题,为什么存在那段代码:

t = g(x)

y = t + tf.stop_gradient(f(x) - t)

这里,我们本来的前向传递函数是XX,但是想要在反向时传递的函数是g(x),因为在前向过程中,tf.stop_gradient()不起作用,因此+t和-t抵消掉了,只剩下f(x)前向传递;而在反向过程中,因为tf.stop_gradient()的作用,使得f(x)-t的梯度变为了0,从而只剩下g(x)在反向传递。

以上这篇Tensorflow中k.gradients()和tf.stop_gradient()用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python复制目录结构脚本代码分享
Mar 06 Python
Python实现比较两个列表(list)范围
Jun 12 Python
python压缩文件夹内所有文件为zip文件的方法
Jun 20 Python
在Python的Django框架中使用通用视图的方法
Jul 21 Python
Python pymongo模块用法示例
Mar 31 Python
解决matplotlib库show()方法不显示图片的问题
May 24 Python
python实现飞机大战
Sep 11 Python
对python中的argv和argc使用详解
Dec 15 Python
python制作填词游戏步骤详解
May 05 Python
selenium+PhantomJS爬取豆瓣读书
Aug 26 Python
PyTorch中的Variable变量详解
Jan 07 Python
win10+anaconda安装yolov5的方法及问题解决方案
Apr 29 Python
PySide2出现“ImportError: DLL load failed: 找不到指定的模块”的问题及解决方法
Jun 10 #Python
浅谈Python中的字符串
Jun 10 #Python
Keras 使用 Lambda层详解
Jun 10 #Python
keras打印loss对权重的导数方式
Jun 10 #Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
You might like
解析PHP无限级分类方法及代码
2013/06/21 PHP
php获取mysql字段名称和其它信息的例子
2014/04/14 PHP
php根据数据id自动生成编号的实现方法
2016/10/16 PHP
CI框架常用函数封装实例
2016/11/21 PHP
php 三大特点:封装,继承,多态
2017/02/19 PHP
自写的利用PDO对mysql数据库增删改查操作类
2018/02/19 PHP
PHP从零开始打造自己的MVC框架之路由类实现方法分析
2019/06/03 PHP
js变量作用域及可访问性的探讨
2006/11/23 Javascript
js获取url中指定参数值的示例代码
2013/12/14 Javascript
Js实现动态添加删除Table行示例
2014/04/14 Javascript
给应用部分的js代码设定一个统一的入口
2014/06/15 Javascript
jquery事件preventDefault()方法用法实例
2015/01/16 Javascript
jquery插件jSignature实现手动签名
2015/05/04 Javascript
简介JavaScript中的italics()方法的使用
2015/06/08 Javascript
web前端开发JQuery常用实例代码片段(50个)
2015/08/28 Javascript
详解JS模块导入导出
2017/12/20 Javascript
微信小程序实现侧边栏分类
2019/10/21 Javascript
js实现限定范围拖拽的示例
2020/10/26 Javascript
[55:56]NB vs Infamous 2019国际邀请赛淘汰赛 败者组 BO3 第二场 8.22
2019/09/05 DOTA
python查看zip包中文件及大小的方法
2015/07/09 Python
你眼中的Python大牛 应该都有这份书单
2017/10/31 Python
使用python3+xlrd解析Excel的实例
2018/05/04 Python
python使用matplotlib画柱状图、散点图
2019/03/18 Python
TensorFlow 显存使用机制详解
2020/02/03 Python
HTML高亮关键字的实现代码
2018/10/22 HTML / CSS
HTML5实现签到 功能
2018/10/09 HTML / CSS
心理健康日活动总结
2014/05/08 职场文书
重阳节标语大全
2014/10/07 职场文书
2014年个人总结范文
2015/03/09 职场文书
党小组鉴定意见
2015/06/02 职场文书
教师个人教学反思
2016/02/23 职场文书
2016年度创先争优活动总结
2016/04/05 职场文书
python如何利用cv2模块读取显示保存图片
2021/06/04 Python
Python 数据可视化之Bokeh详解
2021/11/02 Python
gtx1650怎么样 gtx1650显卡相当于什么级别
2022/04/08 数码科技
Vue 打包后相对路径的引用问题
2022/06/05 Vue.js