tensorflow中的梯度求解及梯度裁剪操作


Posted in Python onMay 26, 2021

1. tensorflow中梯度求解的几种方式

1.1 tf.gradients

tf.gradients(
    ys,
    xs,
    grad_ys=None,
    name='gradients',
    colocate_gradients_with_ops=False,
    gate_gradients=False,
    aggregation_method=None,
    stop_gradients=None,
    unconnected_gradients=tf.UnconnectedGradients.NONE
)

计算ys关于xs的梯度,tf.gradients返回的结果是一个长度为len(xs)的tensor列表list,例如

tf.gradients(y, [x1, x2, x3]返回[dy/dx1, dy/dx2, dy/dx3]

当y与x无关时,即graph无x到y的路径, 则求y关于x的梯度时返回[None];参数stop_gradients指定的变量对当前梯度求解而言, 梯度求解将止于这些变量。

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b], stop_gradients=[a, b]) #梯度计算不再追溯a,b之前的变量

输出:

In: sess.run(g)

out:[1.0, 1.0]

如果不设置stop_gradients参数则反向传播梯度计算将追溯到最开始的值a,输出结果为:

In : sess.run(g)

Out: [3.0, 1.0]

1.2 optimizer.compute_gradients

compute_gradients(
    loss,
    var_list=None,
    gate_gradients=GATE_OP,
    aggregation_method=None,
    colocate_gradients_with_ops=False,
    grad_loss=None
)

optimizer.compute_gradients是tf.gradients的封装,作用相同,但是tfgradients只返回梯度,compute_gradients返回梯度和可导的变量;tf.compute_gradients是optimizer.minimize()的第一步,optimizer.compute_gradients返回一个[(gradient, variable),…]的元组列表,其中gradient是tensor。

直观上,optimizer.compute_gradients只比tf.gradients多了一个variable输出。

optimizer = tf.train.GradientDescentOptimizer(learning_rate = 1.0)
self.train_op = optimizer.minimize(self.cost)
sess.run([train_op], feed_dict={x:data, y:labels})

在这个过程中,调用minimize方法的时候,底层进行的工作包括:

(1) 使用tf.optimizer.compute_gradients计算trainable_variables 集合中所有参数的梯度

(2) 用optimizer.apply_gradients来更新计算得到的梯度对应的变量

上面代码等价于下面代码

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
grads_and_vars = optimizer.compute_gradients(loss)
train_op = optimizer.apply_gradients(grads_and_vars)

1.3 tf.stop_gradient

tf.stop_gradient(
    input,
    name=None
)

tf.stop_gradient阻止input的变量参与梯度计算,即在梯度计算的过程中屏蔽input之前的graph。

返回:关于input的梯度

2. 梯度裁剪

如果我们希望对梯度进行截断,那么就要自己计算出梯度,然后进行clip,最后应用到变量上,代码如下所示,接下来我们一一介绍其中的主要步骤

#return a list of trainable variable in you model
params = tf.trainable_variables()

#create an optimizer
opt = tf.train.GradientDescentOptimizer(self.learning_rate)

#compute gradients for params
gradients = tf.gradients(loss, params)

#process gradients
clipped_gradients, norm = tf.clip_by_global_norm(gradients,max_gradient_norm)

train_op = opt.apply_gradients(zip(clipped_gradients, params)))

2.1 tf.clip_by_global_norm介绍

tf.clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None)

 

t_list 表示梯度张量

clip_norm是截取的比率

在应用这个函数之后,t_list[i]的更新公示变为:

global_norm = sqrt(sum(l2norm(t)**2 for t in t_list))
t_list[i] = t_list[i] * clip_norm / max(global_norm, clip_norm)

也就是分为两步:

(1) 计算所有梯度的平方和global_norm

(2) 如果梯度平方和 global_norm 超过我们指定的clip_norm,那么就对梯度进行缩放;否则就按照原本的计算结果

梯度裁剪实例2

loss = w*x*x
optimizer = tf.train.GradientDescentOptimizer(0.1)
grads_and_vars = optimizer.compute_gradients(loss,[w,x])
grads = tf.gradients(loss,[w,x])
# 修正梯度
for i,(gradient,var) in enumerate(grads_and_vars):
    if gradient is not None:
        grads_and_vars[i] = (tf.clip_by_norm(gradient,5),var)
train_op = optimizer.apply_gradients(grads_and_vars)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(grads_and_vars))
     # 梯度修正前[(9.0, 2.0), (12.0, 3.0)];梯度修正后 ,[(5.0, 2.0), (5.0, 3.0)]
    print(sess.run(grads))  #[9.0, 12.0],
    print(train_op)

补充:tensorflow框架中几种计算梯度的方式

1. tf.gradients

tf.gradients(
    ys,
    xs,
    grad_ys=None,
    name='gradients',
    colocate_gradients_with_ops=False,
    gate_gradients=False,
    aggregation_method=None,
    stop_gradients=None,
    unconnected_gradients=tf.UnconnectedGradients.NONE
)

计算ys关于xs的梯度,tf.gradients返回的结果是一个长度为len(xs)的Tensor列表list,每个张量为sum(dy/dx),即ys关于xs的导数。

例子:

tf.gradients(y, [x1, x2, x3]返回[dy/dx1, dy/dx2, dy/dx3]

当y与x无关时,即graph无x到y的路径, 则求y关于x的梯度时返回[None]

参数stop_gradients指定的变量对当前梯度求解而言, 梯度求解将止于这些变量。

实例:

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b], stop_gradients=[a, b]) #梯度计算不再追溯a,b之前的变量

输出:

In: sess.run(g)

out:[1.0, 1.0]

如果不设置stop_gradients参数则反向传播梯度计算将追溯到最开始的值a,输出结果为:

In : sess.run(g)

Out: [3.0, 1.0]

2. optimizer.compute_gradients

compute_gradients(
    loss,
    var_list=None,
    gate_gradients=GATE_OP,
    aggregation_method=None,
    colocate_gradients_with_ops=False,
    grad_loss=None
)

optimizer.compute_gradients是tf.gradients的封装1.

是optimizer.minimize()的第一步,返回(gradient, variable)的列表,其中gradient是tensor。

直观上,optimizer.compute_gradients只比tf.gradients多了一个variable输出。

3. tf.stop_gradient

tf.stop_gradient(
    input,
    name=None
)

tf.stop_gradient阻止input的变量参与梯度计算,即在梯度计算的过程中屏蔽input之前的graph。

返回:关于input的梯度

应用:

1、EM算法,其中M步骤不应涉及通过E步骤的输出的反向传播。

2、Boltzmann机器的对比散度训练,在区分能量函数时,训练不得反向传播通过模型生成样本的图形。

3、对抗性训练,通过对抗性示例生成过程不会发生反向训练。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python BeautifulSoup使用方法详解
Nov 21 Python
Python 不同对象比较大小示例探讨
Aug 21 Python
Python socket.error: [Errno 98] Address already in use的原因和解决方法
Aug 25 Python
在Python中编写数据库模块的教程
Apr 29 Python
详解django中使用定时任务的方法
Sep 27 Python
python实现汽车管理系统
Nov 30 Python
python tkinter canvas 显示图片的示例
Jun 13 Python
Python argparse模块应用实例解析
Nov 15 Python
python GUI库图形界面开发之PyQt5浏览器控件QWebEngineView详细使用方法
Feb 26 Python
Python requests.post方法中data与json参数区别详解
Apr 30 Python
python3跳出一个循环的实例操作
Aug 18 Python
python中Array和DataFrame相互转换的实例讲解
Feb 03 Python
python numpy中multiply与*及matul 的区别说明
May 26 #Python
python文本处理的方案(结巴分词并去除符号)
Django操作cookie的实现
May 26 #Python
pandas中DataFrame检测重复值的实现
python 中的@运算符使用
May 26 #Python
Python 实现定积分与二重定积分的操作
May 26 #Python
python 解决微分方程的操作(数值解法)
You might like
基于PHP+Ajax实现表单验证的详解
2013/06/25 PHP
js下判断 iframe 是否加载完成的完美方法
2010/10/26 Javascript
JavaScript移除数组元素减少长度的方法
2013/09/05 Javascript
在NodeJS中启用ECMAScript 6小结(windos以及Linux)
2014/07/15 NodeJs
node.js中的path.basename方法使用说明
2014/12/09 Javascript
jQuery源码分析之Callbacks详解
2015/03/13 Javascript
jQuery基本选择器和层次选择器学习使用
2017/02/27 Javascript
IntersectionObserver实现图片懒加载的示例
2017/09/29 Javascript
JavaScript简单实现合并两个Json对象的方法示例
2017/10/16 Javascript
Bootstrap fileinput 上传新文件移除时触发服务器同步删除的配置
2018/10/08 Javascript
javascript实现贪吃蛇经典游戏
2020/04/10 Javascript
vue-router 按需加载 component: () => import() 报错的解决
2020/09/22 Javascript
vue中是怎样监听数组变化的
2020/10/24 Javascript
跟老齐学Python之开始真正编程
2014/09/12 Python
20个常用Python运维库和模块
2018/02/12 Python
Django配置celery(非djcelery)执行异步任务和定时任务
2018/07/16 Python
TensorFlow Session会话控制&Variable变量详解
2018/07/30 Python
Python开发的十个小贴士和技巧及长常犯错误
2018/09/27 Python
浅谈Python 多进程默认不能共享全局变量的问题
2019/01/11 Python
Pytorch实现GoogLeNet的方法
2019/08/18 Python
Python3标准库之threading进程中管理并发操作方法
2020/03/30 Python
Python OpenCV去除字母后面的杂线操作
2020/07/05 Python
python 逆向爬虫正确调用 JAR 加密逻辑
2021/01/12 Python
纯CSS3实现表单验证效果(非常不错)
2017/01/18 HTML / CSS
加拿大购物频道:The Shopping Channel
2016/07/21 全球购物
可打印的优惠券、杂货和优惠券代码:Coupons.com
2018/06/12 全球购物
Pretty Little Thing美国:时尚女性服饰
2018/08/27 全球购物
美国环保妈妈、儿童和婴儿用品购物网站:The Tot
2019/11/24 全球购物
集体婚礼证婚词
2014/01/13 职场文书
物理力学求职信
2014/02/18 职场文书
民主评议党员总结
2014/10/20 职场文书
试用期自我评价范文
2015/03/10 职场文书
检察院起诉意见书
2015/05/20 职场文书
2015年国庆节寄语
2015/08/17 职场文书
《牧场之国》教学反思
2016/02/22 职场文书
数据库之SQL技巧整理案例
2021/07/07 SQL Server