tensorflow中的梯度求解及梯度裁剪操作


Posted in Python onMay 26, 2021

1. tensorflow中梯度求解的几种方式

1.1 tf.gradients

tf.gradients(
    ys,
    xs,
    grad_ys=None,
    name='gradients',
    colocate_gradients_with_ops=False,
    gate_gradients=False,
    aggregation_method=None,
    stop_gradients=None,
    unconnected_gradients=tf.UnconnectedGradients.NONE
)

计算ys关于xs的梯度,tf.gradients返回的结果是一个长度为len(xs)的tensor列表list,例如

tf.gradients(y, [x1, x2, x3]返回[dy/dx1, dy/dx2, dy/dx3]

当y与x无关时,即graph无x到y的路径, 则求y关于x的梯度时返回[None];参数stop_gradients指定的变量对当前梯度求解而言, 梯度求解将止于这些变量。

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b], stop_gradients=[a, b]) #梯度计算不再追溯a,b之前的变量

输出:

In: sess.run(g)

out:[1.0, 1.0]

如果不设置stop_gradients参数则反向传播梯度计算将追溯到最开始的值a,输出结果为:

In : sess.run(g)

Out: [3.0, 1.0]

1.2 optimizer.compute_gradients

compute_gradients(
    loss,
    var_list=None,
    gate_gradients=GATE_OP,
    aggregation_method=None,
    colocate_gradients_with_ops=False,
    grad_loss=None
)

optimizer.compute_gradients是tf.gradients的封装,作用相同,但是tfgradients只返回梯度,compute_gradients返回梯度和可导的变量;tf.compute_gradients是optimizer.minimize()的第一步,optimizer.compute_gradients返回一个[(gradient, variable),…]的元组列表,其中gradient是tensor。

直观上,optimizer.compute_gradients只比tf.gradients多了一个variable输出。

optimizer = tf.train.GradientDescentOptimizer(learning_rate = 1.0)
self.train_op = optimizer.minimize(self.cost)
sess.run([train_op], feed_dict={x:data, y:labels})

在这个过程中,调用minimize方法的时候,底层进行的工作包括:

(1) 使用tf.optimizer.compute_gradients计算trainable_variables 集合中所有参数的梯度

(2) 用optimizer.apply_gradients来更新计算得到的梯度对应的变量

上面代码等价于下面代码

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
grads_and_vars = optimizer.compute_gradients(loss)
train_op = optimizer.apply_gradients(grads_and_vars)

1.3 tf.stop_gradient

tf.stop_gradient(
    input,
    name=None
)

tf.stop_gradient阻止input的变量参与梯度计算,即在梯度计算的过程中屏蔽input之前的graph。

返回:关于input的梯度

2. 梯度裁剪

如果我们希望对梯度进行截断,那么就要自己计算出梯度,然后进行clip,最后应用到变量上,代码如下所示,接下来我们一一介绍其中的主要步骤

#return a list of trainable variable in you model
params = tf.trainable_variables()

#create an optimizer
opt = tf.train.GradientDescentOptimizer(self.learning_rate)

#compute gradients for params
gradients = tf.gradients(loss, params)

#process gradients
clipped_gradients, norm = tf.clip_by_global_norm(gradients,max_gradient_norm)

train_op = opt.apply_gradients(zip(clipped_gradients, params)))

2.1 tf.clip_by_global_norm介绍

tf.clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None)

 

t_list 表示梯度张量

clip_norm是截取的比率

在应用这个函数之后,t_list[i]的更新公示变为:

global_norm = sqrt(sum(l2norm(t)**2 for t in t_list))
t_list[i] = t_list[i] * clip_norm / max(global_norm, clip_norm)

也就是分为两步:

(1) 计算所有梯度的平方和global_norm

(2) 如果梯度平方和 global_norm 超过我们指定的clip_norm,那么就对梯度进行缩放;否则就按照原本的计算结果

梯度裁剪实例2

loss = w*x*x
optimizer = tf.train.GradientDescentOptimizer(0.1)
grads_and_vars = optimizer.compute_gradients(loss,[w,x])
grads = tf.gradients(loss,[w,x])
# 修正梯度
for i,(gradient,var) in enumerate(grads_and_vars):
    if gradient is not None:
        grads_and_vars[i] = (tf.clip_by_norm(gradient,5),var)
train_op = optimizer.apply_gradients(grads_and_vars)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(grads_and_vars))
     # 梯度修正前[(9.0, 2.0), (12.0, 3.0)];梯度修正后 ,[(5.0, 2.0), (5.0, 3.0)]
    print(sess.run(grads))  #[9.0, 12.0],
    print(train_op)

补充:tensorflow框架中几种计算梯度的方式

1. tf.gradients

tf.gradients(
    ys,
    xs,
    grad_ys=None,
    name='gradients',
    colocate_gradients_with_ops=False,
    gate_gradients=False,
    aggregation_method=None,
    stop_gradients=None,
    unconnected_gradients=tf.UnconnectedGradients.NONE
)

计算ys关于xs的梯度,tf.gradients返回的结果是一个长度为len(xs)的Tensor列表list,每个张量为sum(dy/dx),即ys关于xs的导数。

例子:

tf.gradients(y, [x1, x2, x3]返回[dy/dx1, dy/dx2, dy/dx3]

当y与x无关时,即graph无x到y的路径, 则求y关于x的梯度时返回[None]

参数stop_gradients指定的变量对当前梯度求解而言, 梯度求解将止于这些变量。

实例:

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b], stop_gradients=[a, b]) #梯度计算不再追溯a,b之前的变量

输出:

In: sess.run(g)

out:[1.0, 1.0]

如果不设置stop_gradients参数则反向传播梯度计算将追溯到最开始的值a,输出结果为:

In : sess.run(g)

Out: [3.0, 1.0]

2. optimizer.compute_gradients

compute_gradients(
    loss,
    var_list=None,
    gate_gradients=GATE_OP,
    aggregation_method=None,
    colocate_gradients_with_ops=False,
    grad_loss=None
)

optimizer.compute_gradients是tf.gradients的封装1.

是optimizer.minimize()的第一步,返回(gradient, variable)的列表,其中gradient是tensor。

直观上,optimizer.compute_gradients只比tf.gradients多了一个variable输出。

3. tf.stop_gradient

tf.stop_gradient(
    input,
    name=None
)

tf.stop_gradient阻止input的变量参与梯度计算,即在梯度计算的过程中屏蔽input之前的graph。

返回:关于input的梯度

应用:

1、EM算法,其中M步骤不应涉及通过E步骤的输出的反向传播。

2、Boltzmann机器的对比散度训练,在区分能量函数时,训练不得反向传播通过模型生成样本的图形。

3、对抗性训练,通过对抗性示例生成过程不会发生反向训练。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现字符串的逆序 C++字符串逆序算法
May 28 Python
python利用微信公众号实现报警功能
Jun 10 Python
PyQt5实现QLineEdit添加clicked信号的方法
Jun 25 Python
简单了解python 邮件模块的使用方法
Jul 24 Python
python 图片二值化处理(处理后为纯黑白的图片)
Nov 01 Python
pytorch sampler对数据进行采样的实现
Dec 31 Python
解决pycharm编辑区显示yaml文件层级结构遇中文乱码问题
Apr 27 Python
基于python 取余问题(%)详解
Jun 03 Python
python3读取autocad图形文件.py实例
Jun 05 Python
python爬虫中PhantomJS加载页面的实例方法
Nov 12 Python
Python解析微信dat文件的方法
Nov 30 Python
详解python的xlwings库读写excel操作总结
Feb 26 Python
python numpy中multiply与*及matul 的区别说明
May 26 #Python
python文本处理的方案(结巴分词并去除符号)
Django操作cookie的实现
May 26 #Python
pandas中DataFrame检测重复值的实现
python 中的@运算符使用
May 26 #Python
Python 实现定积分与二重定积分的操作
May 26 #Python
python 解决微分方程的操作(数值解法)
You might like
《PHP编程最快明白》第七讲:php图片验证码与缩略图
2010/11/01 PHP
php中get_headers函数的作用及用法的详细介绍
2013/04/27 PHP
IIS安装Apache伪静态插件的具体操作图文
2013/07/01 PHP
php获取从百度搜索进入网站的关键词的详细代码
2014/01/08 PHP
利用php做服务器和web前端的界面进行交互
2016/10/31 PHP
php中strlen和mb_strlen用法实例分析
2016/11/12 PHP
根据key删除数组中指定的元素实现方法
2017/03/02 PHP
PHP设计模式之建造者模式定义与用法简单示例
2018/08/13 PHP
Auntion-TableSort国人写的一个javascript表格排序的东西
2007/11/12 Javascript
JavaScript的变量作用域深入理解
2009/10/25 Javascript
解决表单中第一个非隐藏的元素获得焦点的一个方案
2009/10/26 Javascript
最新的10款jQuery内容滑块插件分享
2011/09/18 Javascript
javascript实现dom动态创建省市纵向列表菜单的方法
2015/05/14 Javascript
JQuery Mobile实现导航栏和页脚
2016/03/09 Javascript
jQuery中事件与动画的总结分享
2016/05/24 Javascript
JSON键值对序列化和反序列化解析
2017/01/24 Javascript
详解本地Node.js服务器作为api服务器的解决办法
2017/02/28 Javascript
jQuery复合事件结合toggle()方法的用法示例
2017/06/10 jQuery
Easyui在treegrid添加控件的实现方法
2017/06/23 Javascript
浅谈express 中间件机制及实现原理
2017/08/31 Javascript
详解用webpack的CommonsChunkPlugin提取公共代码的3种方式
2017/11/09 Javascript
JS实现留言板功能[楼层效果展示]
2017/12/27 Javascript
详解vue.js根据不同环境(正式、测试)打包到不同目录
2018/07/13 Javascript
vue组件挂载到全局方法的示例代码
2018/08/02 Javascript
详解如何解决Vue和vue-template-compiler版本之间的问题
2018/09/17 Javascript
layui动态渲染生成select的option值方法
2019/09/23 Javascript
JS+CSS实现3D切割轮播图
2020/03/21 Javascript
详解Typescript 内置的模块导入兼容方式
2020/05/31 Javascript
javascript实现扫雷简易版
2020/08/18 Javascript
vue 防止页面加载时看到花括号的解决操作
2020/11/09 Javascript
Python yield 使用浅析
2015/05/28 Python
numpy数组拼接简单示例
2017/12/15 Python
使用django和vue进行数据交互的方法步骤
2019/11/11 Python
python使用selenium爬虫知乎的方法示例
2020/10/28 Python
美国最受欢迎的度假租赁网站:VRBO
2016/08/02 全球购物
就业自荐书
2013/12/05 职场文书