tensorflow中的梯度求解及梯度裁剪操作


Posted in Python onMay 26, 2021

1. tensorflow中梯度求解的几种方式

1.1 tf.gradients

tf.gradients(
    ys,
    xs,
    grad_ys=None,
    name='gradients',
    colocate_gradients_with_ops=False,
    gate_gradients=False,
    aggregation_method=None,
    stop_gradients=None,
    unconnected_gradients=tf.UnconnectedGradients.NONE
)

计算ys关于xs的梯度,tf.gradients返回的结果是一个长度为len(xs)的tensor列表list,例如

tf.gradients(y, [x1, x2, x3]返回[dy/dx1, dy/dx2, dy/dx3]

当y与x无关时,即graph无x到y的路径, 则求y关于x的梯度时返回[None];参数stop_gradients指定的变量对当前梯度求解而言, 梯度求解将止于这些变量。

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b], stop_gradients=[a, b]) #梯度计算不再追溯a,b之前的变量

输出:

In: sess.run(g)

out:[1.0, 1.0]

如果不设置stop_gradients参数则反向传播梯度计算将追溯到最开始的值a,输出结果为:

In : sess.run(g)

Out: [3.0, 1.0]

1.2 optimizer.compute_gradients

compute_gradients(
    loss,
    var_list=None,
    gate_gradients=GATE_OP,
    aggregation_method=None,
    colocate_gradients_with_ops=False,
    grad_loss=None
)

optimizer.compute_gradients是tf.gradients的封装,作用相同,但是tfgradients只返回梯度,compute_gradients返回梯度和可导的变量;tf.compute_gradients是optimizer.minimize()的第一步,optimizer.compute_gradients返回一个[(gradient, variable),…]的元组列表,其中gradient是tensor。

直观上,optimizer.compute_gradients只比tf.gradients多了一个variable输出。

optimizer = tf.train.GradientDescentOptimizer(learning_rate = 1.0)
self.train_op = optimizer.minimize(self.cost)
sess.run([train_op], feed_dict={x:data, y:labels})

在这个过程中,调用minimize方法的时候,底层进行的工作包括:

(1) 使用tf.optimizer.compute_gradients计算trainable_variables 集合中所有参数的梯度

(2) 用optimizer.apply_gradients来更新计算得到的梯度对应的变量

上面代码等价于下面代码

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
grads_and_vars = optimizer.compute_gradients(loss)
train_op = optimizer.apply_gradients(grads_and_vars)

1.3 tf.stop_gradient

tf.stop_gradient(
    input,
    name=None
)

tf.stop_gradient阻止input的变量参与梯度计算,即在梯度计算的过程中屏蔽input之前的graph。

返回:关于input的梯度

2. 梯度裁剪

如果我们希望对梯度进行截断,那么就要自己计算出梯度,然后进行clip,最后应用到变量上,代码如下所示,接下来我们一一介绍其中的主要步骤

#return a list of trainable variable in you model
params = tf.trainable_variables()

#create an optimizer
opt = tf.train.GradientDescentOptimizer(self.learning_rate)

#compute gradients for params
gradients = tf.gradients(loss, params)

#process gradients
clipped_gradients, norm = tf.clip_by_global_norm(gradients,max_gradient_norm)

train_op = opt.apply_gradients(zip(clipped_gradients, params)))

2.1 tf.clip_by_global_norm介绍

tf.clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None)

 

t_list 表示梯度张量

clip_norm是截取的比率

在应用这个函数之后,t_list[i]的更新公示变为:

global_norm = sqrt(sum(l2norm(t)**2 for t in t_list))
t_list[i] = t_list[i] * clip_norm / max(global_norm, clip_norm)

也就是分为两步:

(1) 计算所有梯度的平方和global_norm

(2) 如果梯度平方和 global_norm 超过我们指定的clip_norm,那么就对梯度进行缩放;否则就按照原本的计算结果

梯度裁剪实例2

loss = w*x*x
optimizer = tf.train.GradientDescentOptimizer(0.1)
grads_and_vars = optimizer.compute_gradients(loss,[w,x])
grads = tf.gradients(loss,[w,x])
# 修正梯度
for i,(gradient,var) in enumerate(grads_and_vars):
    if gradient is not None:
        grads_and_vars[i] = (tf.clip_by_norm(gradient,5),var)
train_op = optimizer.apply_gradients(grads_and_vars)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(grads_and_vars))
     # 梯度修正前[(9.0, 2.0), (12.0, 3.0)];梯度修正后 ,[(5.0, 2.0), (5.0, 3.0)]
    print(sess.run(grads))  #[9.0, 12.0],
    print(train_op)

补充:tensorflow框架中几种计算梯度的方式

1. tf.gradients

tf.gradients(
    ys,
    xs,
    grad_ys=None,
    name='gradients',
    colocate_gradients_with_ops=False,
    gate_gradients=False,
    aggregation_method=None,
    stop_gradients=None,
    unconnected_gradients=tf.UnconnectedGradients.NONE
)

计算ys关于xs的梯度,tf.gradients返回的结果是一个长度为len(xs)的Tensor列表list,每个张量为sum(dy/dx),即ys关于xs的导数。

例子:

tf.gradients(y, [x1, x2, x3]返回[dy/dx1, dy/dx2, dy/dx3]

当y与x无关时,即graph无x到y的路径, 则求y关于x的梯度时返回[None]

参数stop_gradients指定的变量对当前梯度求解而言, 梯度求解将止于这些变量。

实例:

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b], stop_gradients=[a, b]) #梯度计算不再追溯a,b之前的变量

输出:

In: sess.run(g)

out:[1.0, 1.0]

如果不设置stop_gradients参数则反向传播梯度计算将追溯到最开始的值a,输出结果为:

In : sess.run(g)

Out: [3.0, 1.0]

2. optimizer.compute_gradients

compute_gradients(
    loss,
    var_list=None,
    gate_gradients=GATE_OP,
    aggregation_method=None,
    colocate_gradients_with_ops=False,
    grad_loss=None
)

optimizer.compute_gradients是tf.gradients的封装1.

是optimizer.minimize()的第一步,返回(gradient, variable)的列表,其中gradient是tensor。

直观上,optimizer.compute_gradients只比tf.gradients多了一个variable输出。

3. tf.stop_gradient

tf.stop_gradient(
    input,
    name=None
)

tf.stop_gradient阻止input的变量参与梯度计算,即在梯度计算的过程中屏蔽input之前的graph。

返回:关于input的梯度

应用:

1、EM算法,其中M步骤不应涉及通过E步骤的输出的反向传播。

2、Boltzmann机器的对比散度训练,在区分能量函数时,训练不得反向传播通过模型生成样本的图形。

3、对抗性训练,通过对抗性示例生成过程不会发生反向训练。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python dict remove数组删除(del,pop)
Mar 24 Python
python自动化报告的输出用例详解
May 30 Python
tensorflow实现加载mnist数据集
Sep 08 Python
Python中字符串List按照长度排序
Jul 01 Python
Python实现投影法分割图像示例(一)
Jan 17 Python
python shapely.geometry.polygon任意两个四边形的IOU计算实例
Apr 12 Python
Django REST 异常处理详解
Jul 15 Python
Python数据库封装实现代码示例解析
Sep 05 Python
详解Python3.8+PyQt5+pyqt5-tools+Pycharm配置详细教程
Nov 02 Python
Python接口自动化测试框架运行原理及流程
Nov 30 Python
ASP.NET Core中的配置详解
Feb 05 Python
python 制作磁力搜索工具
Mar 04 Python
python numpy中multiply与*及matul 的区别说明
May 26 #Python
python文本处理的方案(结巴分词并去除符号)
Django操作cookie的实现
May 26 #Python
pandas中DataFrame检测重复值的实现
python 中的@运算符使用
May 26 #Python
Python 实现定积分与二重定积分的操作
May 26 #Python
python 解决微分方程的操作(数值解法)
You might like
FleaPHP的安全设置方法
2008/09/15 PHP
谈PHP生成静态页面分析 模板+缓存+写文件
2009/08/17 PHP
那些年一起学习的PHP(三)
2012/03/22 PHP
完美利用Yii2微信后台开发的系列总结
2016/07/18 PHP
JavaScript中Object和Function的关系小结
2009/09/26 Javascript
深入理解javascript学习笔记(一) 编写高质量代码
2012/08/09 Javascript
JS:window.onload的使用介绍
2013/11/13 Javascript
js实现的点击数量加一可操作数据库
2014/05/09 Javascript
jQuery控制的不同方向的滑动(向左、向右滑动等)
2014/07/18 Javascript
jQuery模拟原生态App上拉刷新下拉加载更多页面及原理
2015/08/10 Javascript
在Web项目中引入Jquery插件报错的完美解决方案(图解)
2016/09/19 Javascript
用jquery的attr方法实现图片切换效果
2017/02/05 Javascript
微信小程序之swiper轮播图中的图片自适应高度的方法
2018/04/23 Javascript
浅谈React Event实现原理
2018/09/20 Javascript
微信小程序上传多图到服务器并获取返回的路径
2019/05/05 Javascript
Vue组件间的通信pubsub-js实现步骤解析
2020/03/11 Javascript
JS内置对象和Math对象知识点详解
2020/04/03 Javascript
在Python的Django框架中编写错误提示页面
2015/07/22 Python
Python无损音乐搜索引擎实现代码
2018/02/02 Python
python实现图书馆研习室自动预约功能
2018/04/27 Python
python 2.7 检测一个网页是否能正常访问的方法
2018/12/26 Python
djang常用查询SQL语句的使用代码
2019/02/15 Python
Araks官网:纽约内衣品牌
2020/10/15 全球购物
军人违纪检讨书
2014/02/04 职场文书
美术教学感言
2014/02/22 职场文书
酒店员工培训方案
2014/06/02 职场文书
公司自我介绍演讲稿
2014/08/21 职场文书
2015廉洁自律个人总结
2015/02/14 职场文书
2015年营业员工作总结
2015/04/23 职场文书
廉政承诺书2015
2015/04/28 职场文书
Django展示可视化图表的多种方式
2021/04/08 Python
Python 多线程之threading 模块的使用
2021/04/14 Python
OpenCV-Python 实现两张图片自动拼接成全景图
2021/06/11 Python
webpack的移动端适配方案小结
2021/07/25 Javascript
SQL语句多表联合查询的方法示例
2022/04/18 MySQL
Golang jwt身份认证
2022/04/20 Golang