Tensorflow之梯度裁剪的实现示例


Posted in Python onMarch 08, 2020

tensorflow中的梯度计算和更新

为了解决深度学习中常见的梯度消失(gradient explosion)和梯度爆炸(gradients vanishing)问题,tensorflow中所有的优化器tf.train.xxxOptimizer都有两个方法:

  1. compute_gradients
  2. apply_gradients

compute_gradients

对于compute_gradients方法,计算var_list中参数的梯度,使得loss变小。默认情况下,var_list为GraphKeys.TRAINABLE_VARIABLES中的所有参数。

compute_gradients方法返回由多个(gradients, variable)二元组组成的列表。

compute_gradients(
  loss,
  var_list=None,
  gate_gradients=GATE_OP,
  aggregation_method=None,
  colocate_gradients_with_ops=False,
  grad_loss=None
)

apply_gradients

对于apply_gradients方法,根据compute_gradients的返回结果对参数进行更新

apply_gradients(
  grads_and_vars,
  global_step=None,
  name=None
)

梯度裁剪(Gradient Clipping)

tensorflow中裁剪梯度的几种方式

方法一tf.clip_by_value

def clip_by_value(t, clip_value_min, clip_value_max,
         name=None):

其中,t为一个张量,clip_by_value返回一个与t的type相同、shape相同的张量,但是新tensor中的值被裁剪到了clip_value_min和clip_value_max之间。

方法二:tf.clip_by_global_norm

def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):

其中,t_list为A tuple or list of mixed Tensors, IndexedSlices, or None。clip_norm为clipping ratio,use_norm指定global_norm,如果use_norm为None,则按global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))计算global_norm。

最终,梯度的裁剪方式为

Tensorflow之梯度裁剪的实现示例

可知,如果clip_norm > global_norm, 则不对梯度进行裁剪,否则对梯度进行缩放。

scale = clip_norm * math_ops.minimum(
    1.0 / use_norm,
    constant_op.constant(1.0, dtype=use_norm.dtype) / clip_norm)

方法的返回值为裁剪后的梯度列表list_clipped和global_norm

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate)
gradients, v = zip(*optimizer.compute_gradients(loss))
gradients, _ = tf.clip_by_global_norm(gradients, grad_clip)
updates = optimizer.apply_gradients(zip(gradients, v),global_step=global_step)

方法三tf.clip_by_average_norm

def clip_by_average_norm(t, clip_norm, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式如下,

Tensorflow之梯度裁剪的实现示例

其中,avg_norm=l2norm_avg(t)

方法四:tf.clip_by_norm

def clip_by_norm(t, clip_norm, axes=None, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式为

Tensorflow之梯度裁剪的实现示例

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
grads = optimizer.compute_gradients(cost)
for i, (g, v) in enumerate(grads):
  if g is not None:
    grads[i] = (tf.clip_by_norm(g, 5), v) # clip gradients
train_op = optimizer.apply_gradients(grads)

注意到,clip_by_value、clib_by-avg_norm和clip_by_norm都是针对于单个张量的,而clip_by_global_norm可用于多个张量组成的列表。

到此这篇关于Tensorflow之梯度裁剪的实现示例的文章就介绍到这了,更多相关Tensorflow 梯度裁剪内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python如何为图片添加水印
Nov 25 Python
Python+matplotlib实现计算两个信号的交叉谱密度实例
Jan 08 Python
python如何重载模块实例解析
Jan 25 Python
Python requests发送post请求的一些疑点
May 20 Python
从运行效率与开发效率比较Python和C++
Dec 14 Python
Django的用户模块与权限系统的示例代码
Jul 24 Python
基于Python获取城市近7天天气预报
Nov 26 Python
python应用Axes3D绘图(批量梯度下降算法)
Mar 25 Python
Python flask路由间传递变量实例详解
Jun 03 Python
导致python中import错误的原因是什么
Jul 01 Python
python实现简单遗传算法
Sep 18 Python
Django自带用户认证系统使用方法解析
Nov 12 Python
Django自定义全局403、404、500错误页面的示例代码
Mar 08 #Python
Django 自定义404 500等错误页面的实现
Mar 08 #Python
Python loguru日志库之高效输出控制台日志和日志记录
Mar 07 #Python
Centos7下源码安装Python3 及shell 脚本自动安装Python3的教程
Mar 07 #Python
Django接收照片储存文件的实例代码
Mar 07 #Python
Python实现对adb命令封装
Mar 06 #Python
对Python中 \r, \n, \r\n的彻底理解
Mar 06 #Python
You might like
php中使用接口实现工厂设计模式的代码
2012/06/17 PHP
php生成缩略图示例代码分享(使用gd库实现)
2014/01/20 PHP
PHP 中使用explode()函数切割字符串为数组的示例
2017/05/06 PHP
PHPstorm启用自动换行的方法详解(IDE)
2020/09/17 PHP
jquery插件冲突(jquery.noconflict)解决方法分享
2014/03/20 Javascript
JavaScript中innerHTML,innerText,outerHTML的用法及区别
2015/09/01 Javascript
apply和call方法定义及apply和call方法的区别
2015/11/15 Javascript
jquery插件autocomplete用法示例
2016/07/01 Javascript
javascript实现简易计算器
2017/02/01 Javascript
Bootstrap里的文件分别代表什么意思及其引用方法
2017/05/01 Javascript
Vue.js+Layer表格数据绑定与实现更新的实例
2018/03/07 Javascript
vue中v-cloak解决刷新或者加载出现闪烁问题(显示变量)
2018/04/20 Javascript
JavaScript基于对象方法实现数组去重及排序操作示例
2018/07/10 Javascript
解决vue的 v-for 循环中图片加载路径问题
2018/09/03 Javascript
JavaScript Array对象基本方法详解
2019/09/03 Javascript
原生js基于canvas实现一个简单的前端截图工具代码实例
2019/09/10 Javascript
python基础教程之lambda表达式使用方法
2014/02/12 Python
Python中的ConfigParser模块使用详解
2015/05/04 Python
python集合用法实例分析
2015/05/30 Python
Python 数据结构之堆栈实例代码
2017/01/22 Python
详解Python 2.6 升级至 Python 2.7 的实践心得
2017/04/27 Python
Python爬虫设置代理IP(图文)
2018/12/23 Python
python excel和yaml文件的读取封装
2021/01/12 Python
Lands’ End官网:经典的美国生活方式品牌
2016/08/14 全球购物
Carolina工作鞋官网:Carolina Footwear
2019/03/14 全球购物
几个SQL的面试题
2014/03/08 面试题
关于人生的感言
2014/01/17 职场文书
优秀小学生家长评语
2014/01/30 职场文书
三方协议书范本
2014/04/22 职场文书
2014乡镇干部纪律作风整顿思想汇报
2014/09/13 职场文书
婚前协议书范本
2014/10/27 职场文书
男方婚前保证书
2015/02/28 职场文书
离婚律师函范本
2015/05/27 职场文书
认识实习感想
2015/08/10 职场文书
三严三实学习心得体会(精选N篇)
2016/01/05 职场文书
Vue项目打包、合并及压缩优化网页响应速度
2021/07/07 Vue.js