Tensorflow中k.gradients()和tf.stop_gradient()用法说明


Posted in Python onJune 10, 2020

上周在实验室开荒某个代码,看到中间这么一段,对Tensorflow中的stop_gradient()还不熟悉,特此周末进行重新并总结。

y = xx + K.stop_gradient(rounded - xx)

这代码最终调用位置在tensoflow.python.ops.gen_array_ops.stop_gradient(input, name=None),关于这段代码为什么这样写的意义在文末给出。

【stop_gradient()意义】

用stop_gradient生成损失函数w.r.t.的梯度。

【tf.gradients()理解】

tf中我们只需要设计我们自己的函数,tf提供提供强大的自动计算函数梯度方法,tf.gradients()。

tf.gradients(
 ys,
 xs,
 grad_ys=None,
 name='gradients',
 colocate_gradients_with_ops=False,
 gate_gradients=False,
 aggregation_method=None,
 stop_gradients=None,
 unconnected_gradients=tf.UnconnectedGradients.NONE
)

gradients() adds ops to the graph to output the derivatives of ys with respect to xs. It returns a list of Tensor of length len(xs) where each tensor is the sum(dy/dx) for y in ys.

1、tf.gradients()实现ys对xs的求导

2、ys和xs可以是Tensor或者list包含的Tensor

3、求导返回值是一个list,list的长度等于len(xs)

eg.假设返回值是[grad1, grad2, grad3],ys=[y1, y2],xs=[x1, x2, x3]。则计算过程为:

Tensorflow中k.gradients()和tf.stop_gradient()用法说明

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
gradients_node = tf.gradients(loss_op, w)
 
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([train_op, gradients_node, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

自定义梯度和更新函数

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
# train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
 
#自定义权重更新
grad_w, grad_b = tf.gradients(loss_op, [w, b])
new_w = w.assign(w - 0.01 * grad_w)
new_b = b.assign(b - 0.01 * grad_b)
 
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([new_w, new_b, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

【tf.stop_gradient()理解】

在tf.gradients()参数中存在stop_gradients,这是一个List,list中的元素是tensorflow graph中的op,一旦进入这个list,将不会被计算梯度,更重要的是,在该op之后的BP计算都不会运行。

import numpy as np
import tensorflow as tf
 
a = tf.constant(0.)
b = 2 * a
c = a + b
g = tf.gradients(c, [a, b])
 
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(g))
 
#输出[3.0, 1.0]

在用一个stop_gradient()的例子

import tensorflow as tf
 
#实验一
w1 = tf.Variable(2.0)
w2 = tf.Variable(2.0)
a = tf.multiply(w1, 3.0)
a_stoped = tf.stop_gradient(a)
 
# b=w1*3.0*w2
b = tf.multiply(a_stoped, w2)
gradients = tf.gradients(b, xs=[w1, w2])
print(gradients)
#输出[None, <tf.Tensor 'gradients/Mul_1_grad/Reshape_1:0' shape=() dtype=float32>]
 
#实验二
a = tf.Variable(1.0)
b = tf.Variable(1.0)
c = tf.add(a, b)
c_stoped = tf.stop_gradient(c)
d = tf.add(a, b)
e = tf.add(c_stoped, d)
gradients = tf.gradients(e, xs=[a, b])
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(gradients))
 
#因为梯度从另外地方传回,所以输出 [1.0, 1.0]

【答案】

开始提出的问题,为什么存在那段代码:

t = g(x)

y = t + tf.stop_gradient(f(x) - t)

这里,我们本来的前向传递函数是XX,但是想要在反向时传递的函数是g(x),因为在前向过程中,tf.stop_gradient()不起作用,因此+t和-t抵消掉了,只剩下f(x)前向传递;而在反向过程中,因为tf.stop_gradient()的作用,使得f(x)-t的梯度变为了0,从而只剩下g(x)在反向传递。

以上这篇Tensorflow中k.gradients()和tf.stop_gradient()用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中的装饰器详解
Apr 13 Python
Python中利用Scipy包的SIFT方法进行图片识别的实例教程
Jun 03 Python
简单谈谈Python中的几种常见的数据类型
Feb 10 Python
Django自定义认证方式用法示例
Jun 23 Python
python装饰器实例大详解
Oct 25 Python
python使用openpyxl库修改excel表格数据方法
May 03 Python
用Python PIL实现几个简单的图片特效
Jan 18 Python
Python高级编程之继承问题详解(super与mro)
Nov 19 Python
python:批量统计xml中各类目标的数量案例
Mar 10 Python
在python3.64中安装pyinstaller库的方法步骤
Jun 02 Python
使用python操作lmdb对数据读取的实例
Dec 11 Python
使用Python拟合函数曲线
Apr 14 Python
PySide2出现“ImportError: DLL load failed: 找不到指定的模块”的问题及解决方法
Jun 10 #Python
浅谈Python中的字符串
Jun 10 #Python
Keras 使用 Lambda层详解
Jun 10 #Python
keras打印loss对权重的导数方式
Jun 10 #Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
You might like
特详细的PHPMYADMIN简明安装教程
2008/08/01 PHP
基于preg_match_all采集后数据处理的一点心得笔记(编码转换和正则匹配)
2014/01/31 PHP
PHP防止注入攻击实例分析
2014/11/03 PHP
JQuery中getJSON的使用方法
2010/12/13 Javascript
模拟select的代码
2011/10/19 Javascript
利用JQuery和JS实现奇偶行背景颜色自定义效果
2012/11/19 Javascript
jQuery fadeTo方法调整图片的透明度使用介绍
2013/05/06 Javascript
js操作table示例(个人心得)
2013/11/29 Javascript
Node.js操作redis实现添加查询功能
2017/05/25 Javascript
Bootstrap3.3.7导航栏下拉菜单鼠标滑过展开效果
2017/10/31 Javascript
bootstrap select下拉搜索插件使用方法详解
2017/11/23 Javascript
jQuery niceScroll滚动条错位问题的解决方法
2018/02/03 jQuery
详解微信小程序调用支付接口支付
2019/04/28 Javascript
解决vue单页面修改样式无法覆盖问题
2019/08/05 Javascript
解决vue项目刷新后,导航菜单高亮显示的位置不对问题
2019/11/01 Javascript
Element InputNumber计数器的使用方法
2020/07/27 Javascript
[03:08]TI9战队档案 - Vici Gaming
2019/08/20 DOTA
[01:37]PWL S2开团时刻DAY1&2——这符有毒
2020/11/20 DOTA
python flask实现分页效果
2017/06/27 Python
Python中static相关知识小结
2018/01/02 Python
Python 寻找局部最高点的实现
2019/12/05 Python
python语言是免费还是收费的?
2020/06/15 Python
解决django migrate报错ORA-02000: missing ALWAYS keyword
2020/07/02 Python
阿姆斯特丹城市卡:Amsterdam Pass
2019/12/01 全球购物
捷克建筑材料网上商店:DEK.cz
2021/03/06 全球购物
金蝶的一道SQL笔试题
2012/12/18 面试题
当x.equals(y)等于true时,x.hashCode()与y.hashCode()可以不相等,这句话对不对
2015/05/02 面试题
销售人员个人求职信
2013/09/26 职场文书
《童趣》教学反思
2014/02/19 职场文书
小学生演讲稿大全
2014/04/25 职场文书
高中学校对照检查材料
2014/08/31 职场文书
银行柜员与客户起冲突检讨书
2014/09/27 职场文书
督导岗位职责
2015/02/04 职场文书
党员转正介绍人意见
2015/06/03 职场文书
MySQL系列之十三 MySQL的复制
2021/07/02 MySQL
SQL SERVER中的流程控制语句
2022/05/25 SQL Server