tensorflow中的梯度求解及梯度裁剪操作


Posted in Python onMay 26, 2021

1. tensorflow中梯度求解的几种方式

1.1 tf.gradients

tf.gradients(
    ys,
    xs,
    grad_ys=None,
    name='gradients',
    colocate_gradients_with_ops=False,
    gate_gradients=False,
    aggregation_method=None,
    stop_gradients=None,
    unconnected_gradients=tf.UnconnectedGradients.NONE
)

计算ys关于xs的梯度,tf.gradients返回的结果是一个长度为len(xs)的tensor列表list,例如

tf.gradients(y, [x1, x2, x3]返回[dy/dx1, dy/dx2, dy/dx3]

当y与x无关时,即graph无x到y的路径, 则求y关于x的梯度时返回[None];参数stop_gradients指定的变量对当前梯度求解而言, 梯度求解将止于这些变量。

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b], stop_gradients=[a, b]) #梯度计算不再追溯a,b之前的变量

输出:

In: sess.run(g)

out:[1.0, 1.0]

如果不设置stop_gradients参数则反向传播梯度计算将追溯到最开始的值a,输出结果为:

In : sess.run(g)

Out: [3.0, 1.0]

1.2 optimizer.compute_gradients

compute_gradients(
    loss,
    var_list=None,
    gate_gradients=GATE_OP,
    aggregation_method=None,
    colocate_gradients_with_ops=False,
    grad_loss=None
)

optimizer.compute_gradients是tf.gradients的封装,作用相同,但是tfgradients只返回梯度,compute_gradients返回梯度和可导的变量;tf.compute_gradients是optimizer.minimize()的第一步,optimizer.compute_gradients返回一个[(gradient, variable),…]的元组列表,其中gradient是tensor。

直观上,optimizer.compute_gradients只比tf.gradients多了一个variable输出。

optimizer = tf.train.GradientDescentOptimizer(learning_rate = 1.0)
self.train_op = optimizer.minimize(self.cost)
sess.run([train_op], feed_dict={x:data, y:labels})

在这个过程中,调用minimize方法的时候,底层进行的工作包括:

(1) 使用tf.optimizer.compute_gradients计算trainable_variables 集合中所有参数的梯度

(2) 用optimizer.apply_gradients来更新计算得到的梯度对应的变量

上面代码等价于下面代码

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
grads_and_vars = optimizer.compute_gradients(loss)
train_op = optimizer.apply_gradients(grads_and_vars)

1.3 tf.stop_gradient

tf.stop_gradient(
    input,
    name=None
)

tf.stop_gradient阻止input的变量参与梯度计算,即在梯度计算的过程中屏蔽input之前的graph。

返回:关于input的梯度

2. 梯度裁剪

如果我们希望对梯度进行截断,那么就要自己计算出梯度,然后进行clip,最后应用到变量上,代码如下所示,接下来我们一一介绍其中的主要步骤

#return a list of trainable variable in you model
params = tf.trainable_variables()

#create an optimizer
opt = tf.train.GradientDescentOptimizer(self.learning_rate)

#compute gradients for params
gradients = tf.gradients(loss, params)

#process gradients
clipped_gradients, norm = tf.clip_by_global_norm(gradients,max_gradient_norm)

train_op = opt.apply_gradients(zip(clipped_gradients, params)))

2.1 tf.clip_by_global_norm介绍

tf.clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None)

 

t_list 表示梯度张量

clip_norm是截取的比率

在应用这个函数之后,t_list[i]的更新公示变为:

global_norm = sqrt(sum(l2norm(t)**2 for t in t_list))
t_list[i] = t_list[i] * clip_norm / max(global_norm, clip_norm)

也就是分为两步:

(1) 计算所有梯度的平方和global_norm

(2) 如果梯度平方和 global_norm 超过我们指定的clip_norm,那么就对梯度进行缩放;否则就按照原本的计算结果

梯度裁剪实例2

loss = w*x*x
optimizer = tf.train.GradientDescentOptimizer(0.1)
grads_and_vars = optimizer.compute_gradients(loss,[w,x])
grads = tf.gradients(loss,[w,x])
# 修正梯度
for i,(gradient,var) in enumerate(grads_and_vars):
    if gradient is not None:
        grads_and_vars[i] = (tf.clip_by_norm(gradient,5),var)
train_op = optimizer.apply_gradients(grads_and_vars)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(grads_and_vars))
     # 梯度修正前[(9.0, 2.0), (12.0, 3.0)];梯度修正后 ,[(5.0, 2.0), (5.0, 3.0)]
    print(sess.run(grads))  #[9.0, 12.0],
    print(train_op)

补充:tensorflow框架中几种计算梯度的方式

1. tf.gradients

tf.gradients(
    ys,
    xs,
    grad_ys=None,
    name='gradients',
    colocate_gradients_with_ops=False,
    gate_gradients=False,
    aggregation_method=None,
    stop_gradients=None,
    unconnected_gradients=tf.UnconnectedGradients.NONE
)

计算ys关于xs的梯度,tf.gradients返回的结果是一个长度为len(xs)的Tensor列表list,每个张量为sum(dy/dx),即ys关于xs的导数。

例子:

tf.gradients(y, [x1, x2, x3]返回[dy/dx1, dy/dx2, dy/dx3]

当y与x无关时,即graph无x到y的路径, 则求y关于x的梯度时返回[None]

参数stop_gradients指定的变量对当前梯度求解而言, 梯度求解将止于这些变量。

实例:

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b], stop_gradients=[a, b]) #梯度计算不再追溯a,b之前的变量

输出:

In: sess.run(g)

out:[1.0, 1.0]

如果不设置stop_gradients参数则反向传播梯度计算将追溯到最开始的值a,输出结果为:

In : sess.run(g)

Out: [3.0, 1.0]

2. optimizer.compute_gradients

compute_gradients(
    loss,
    var_list=None,
    gate_gradients=GATE_OP,
    aggregation_method=None,
    colocate_gradients_with_ops=False,
    grad_loss=None
)

optimizer.compute_gradients是tf.gradients的封装1.

是optimizer.minimize()的第一步,返回(gradient, variable)的列表,其中gradient是tensor。

直观上,optimizer.compute_gradients只比tf.gradients多了一个variable输出。

3. tf.stop_gradient

tf.stop_gradient(
    input,
    name=None
)

tf.stop_gradient阻止input的变量参与梯度计算,即在梯度计算的过程中屏蔽input之前的graph。

返回:关于input的梯度

应用:

1、EM算法,其中M步骤不应涉及通过E步骤的输出的反向传播。

2、Boltzmann机器的对比散度训练,在区分能量函数时,训练不得反向传播通过模型生成样本的图形。

3、对抗性训练,通过对抗性示例生成过程不会发生反向训练。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python时区设置方法与pytz查询时区教程
Nov 27 Python
Python实现从脚本里运行scrapy的方法
Apr 07 Python
详解详解Python中writelines()方法的使用
May 25 Python
python 实现删除文件或文件夹实例详解
Dec 04 Python
Python实现输出程序执行进度百分比的方法
Sep 16 Python
Python3 安装PyQt5及exe打包图文教程
Jan 08 Python
Python 中的 import 机制之实现远程导入模块
Oct 29 Python
Python 依赖库太多了该如何管理
Nov 08 Python
python自动下载图片的方法示例
Mar 25 Python
python GUI计算器的实现
Oct 09 Python
浅谈Selenium+Webdriver 常用的元素定位方式
Jan 13 Python
python 管理系统实现mysql交互的示例代码
Dec 06 Python
python numpy中multiply与*及matul 的区别说明
May 26 #Python
python文本处理的方案(结巴分词并去除符号)
Django操作cookie的实现
May 26 #Python
pandas中DataFrame检测重复值的实现
python 中的@运算符使用
May 26 #Python
Python 实现定积分与二重定积分的操作
May 26 #Python
python 解决微分方程的操作(数值解法)
You might like
PHP删除目录及目录下所有文件的方法详解
2013/06/06 PHP
WordPress中is_singular()函数简介
2015/02/05 PHP
使用WordPress发送电子邮件的相关PHP函数用法解析
2015/12/15 PHP
PHP字符串逆序排列实现方法小结【strrev函数,二分法,循环法,递归法】
2017/01/13 PHP
Laravel定时任务的每秒执行代码
2019/10/22 PHP
TP5(thinkPHP5)框架使用ajax实现与后台数据交互的方法小结
2020/02/10 PHP
php设计模式之职责链模式实例分析【星际争霸游戏案例】
2020/03/27 PHP
不错的asp中显示新闻的功能
2006/10/13 Javascript
Javascript与vbscript数据共享
2007/01/09 Javascript
js批量设置样式的三种方法不推荐使用with
2013/02/25 Javascript
JavaScript 操作table,可以新增行和列并且隔一行换背景色代码分享
2013/07/05 Javascript
浅析Node.js的Stream模块中的Readable对象
2015/07/29 Javascript
JavaScript中的原型prototype完全解析
2016/05/10 Javascript
js 弹出对话框(遮罩)透明,可拖动的简单实例
2016/07/11 Javascript
原生javascript 学习之js变量全面了解
2016/07/14 Javascript
JS中数组重排序方法
2016/11/11 Javascript
js实现界面向原生界面发消息并跳转功能
2016/11/22 Javascript
详解angular ui-grid之过滤器设置
2017/06/07 Javascript
在axios中使用params传参的时候传入数组的方法
2018/09/25 Javascript
基于vue和react的spa进行按需加载的实现方法
2018/09/29 Javascript
详解js location.href和window.open的几种用法和区别
2019/12/02 Javascript
webpack+vue-cil 中proxyTable配置接口地址代理操作
2020/07/18 Javascript
[01:34]DOTA2 7.22版本新增神杖效果一览(敏捷英雄篇)
2019/05/28 DOTA
python数据结构之二叉树的统计与转换实例
2014/04/29 Python
python使用PyGame播放Midi和Mp3文件的方法
2015/04/24 Python
Python实现统计单词出现的个数
2015/05/28 Python
Python解析json之ValueError: Expecting property name enclosed in double quotes: line 1 column 2(char 1)
2017/07/06 Python
Pandas之Fillna填充缺失数据的方法
2019/06/25 Python
Python3 字典dictionary入门基础附实例
2020/02/10 Python
Python+OpenCV图像处理——实现轮廓发现
2020/10/23 Python
python 自动识别并连接串口的实现
2021/01/19 Python
美国流行背包品牌:JanSport(杰斯伯)
2018/03/02 全球购物
库房主管岗位职责
2013/12/31 职场文书
中学感恩教育活动总结
2015/05/05 职场文书
2016年安全月活动总结
2016/04/06 职场文书
VS2019连接MySQL数据库的过程及常见问题总结
2021/11/27 MySQL