Tensorflow之梯度裁剪的实现示例


Posted in Python onMarch 08, 2020

tensorflow中的梯度计算和更新

为了解决深度学习中常见的梯度消失(gradient explosion)和梯度爆炸(gradients vanishing)问题,tensorflow中所有的优化器tf.train.xxxOptimizer都有两个方法:

  1. compute_gradients
  2. apply_gradients

compute_gradients

对于compute_gradients方法,计算var_list中参数的梯度,使得loss变小。默认情况下,var_list为GraphKeys.TRAINABLE_VARIABLES中的所有参数。

compute_gradients方法返回由多个(gradients, variable)二元组组成的列表。

compute_gradients(
  loss,
  var_list=None,
  gate_gradients=GATE_OP,
  aggregation_method=None,
  colocate_gradients_with_ops=False,
  grad_loss=None
)

apply_gradients

对于apply_gradients方法,根据compute_gradients的返回结果对参数进行更新

apply_gradients(
  grads_and_vars,
  global_step=None,
  name=None
)

梯度裁剪(Gradient Clipping)

tensorflow中裁剪梯度的几种方式

方法一tf.clip_by_value

def clip_by_value(t, clip_value_min, clip_value_max,
         name=None):

其中,t为一个张量,clip_by_value返回一个与t的type相同、shape相同的张量,但是新tensor中的值被裁剪到了clip_value_min和clip_value_max之间。

方法二:tf.clip_by_global_norm

def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):

其中,t_list为A tuple or list of mixed Tensors, IndexedSlices, or None。clip_norm为clipping ratio,use_norm指定global_norm,如果use_norm为None,则按global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))计算global_norm。

最终,梯度的裁剪方式为

Tensorflow之梯度裁剪的实现示例

可知,如果clip_norm > global_norm, 则不对梯度进行裁剪,否则对梯度进行缩放。

scale = clip_norm * math_ops.minimum(
    1.0 / use_norm,
    constant_op.constant(1.0, dtype=use_norm.dtype) / clip_norm)

方法的返回值为裁剪后的梯度列表list_clipped和global_norm

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate)
gradients, v = zip(*optimizer.compute_gradients(loss))
gradients, _ = tf.clip_by_global_norm(gradients, grad_clip)
updates = optimizer.apply_gradients(zip(gradients, v),global_step=global_step)

方法三tf.clip_by_average_norm

def clip_by_average_norm(t, clip_norm, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式如下,

Tensorflow之梯度裁剪的实现示例

其中,avg_norm=l2norm_avg(t)

方法四:tf.clip_by_norm

def clip_by_norm(t, clip_norm, axes=None, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式为

Tensorflow之梯度裁剪的实现示例

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
grads = optimizer.compute_gradients(cost)
for i, (g, v) in enumerate(grads):
  if g is not None:
    grads[i] = (tf.clip_by_norm(g, 5), v) # clip gradients
train_op = optimizer.apply_gradients(grads)

注意到,clip_by_value、clib_by-avg_norm和clip_by_norm都是针对于单个张量的,而clip_by_global_norm可用于多个张量组成的列表。

到此这篇关于Tensorflow之梯度裁剪的实现示例的文章就介绍到这了,更多相关Tensorflow 梯度裁剪内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python基础教程之字典操作详解
Mar 25 Python
python中类的一些方法分析
Sep 25 Python
Python3通过Luhn算法快速验证信用卡卡号的方法
May 14 Python
python爬虫入门教程--正则表达式完全指南(五)
May 25 Python
分享Python切分字符串的一个不错方法
Dec 14 Python
python selenium执行所有测试用例并生成报告的方法
Feb 13 Python
django用户登录验证的完整示例代码
Jul 21 Python
python装饰器代替set get方法实例
Dec 19 Python
Python常用库大全及简要说明
Jan 17 Python
Python3使用xlrd、xlwt处理Excel方法数据
Feb 28 Python
Python定义一个Actor任务
Jul 29 Python
详解Python生成器和基于生成器的协程
Jun 03 Python
Django自定义全局403、404、500错误页面的示例代码
Mar 08 #Python
Django 自定义404 500等错误页面的实现
Mar 08 #Python
Python loguru日志库之高效输出控制台日志和日志记录
Mar 07 #Python
Centos7下源码安装Python3 及shell 脚本自动安装Python3的教程
Mar 07 #Python
Django接收照片储存文件的实例代码
Mar 07 #Python
Python实现对adb命令封装
Mar 06 #Python
对Python中 \r, \n, \r\n的彻底理解
Mar 06 #Python
You might like
PHP 导出Excel示例分享
2014/08/18 PHP
PHP随机生成唯一HASH值自定义函数
2015/04/20 PHP
mysql alter table命令修改表结构实例详解
2016/09/24 PHP
PHP带节点操作的无限分类实现方法详解
2016/11/09 PHP
PHP Swoole异步Redis客户端实现方法示例
2019/10/24 PHP
Javascript 遍历对象中的子对象
2009/07/03 Javascript
js实现兼容IE和FF的上下层的移动
2015/05/04 Javascript
javascript 常见功能汇总
2015/06/11 Javascript
AngularJs 国际化(I18n/L10n)详解
2016/09/01 Javascript
折叠菜单及选择器的运用
2017/02/03 Javascript
JS实现无缝循环marquee滚动效果
2017/05/22 Javascript
AngularJS通过ng-Img-Crop实现头像截取的示例
2017/08/17 Javascript
Vue自定义指令使用方法详解
2017/08/21 Javascript
详解js中Array的方法及技巧
2018/09/12 Javascript
vue路由前进后退动画效果的实现代码
2018/12/10 Javascript
JS异步错误捕获的一些事小结
2019/04/26 Javascript
js实现内置计时器
2019/12/16 Javascript
Vue设置长时间未操作登录自动到期返回登录页
2020/01/22 Javascript
聊聊vue 中的v-on参数问题
2021/01/29 Vue.js
Python使用QRCode模块生成二维码实例详解
2017/06/14 Python
详解django中自定义标签和过滤器
2017/07/03 Python
python中的for循环
2018/09/28 Python
python3.6+django2.0+mysql搭建网站过程详解
2019/07/24 Python
SpringBoot实现登录注册常见问题解决方案
2020/03/04 Python
python3.6.8 + pycharm + PyQt5 环境搭建的图文教程
2020/06/11 Python
html5响应式开发自动计算fontSize的方法
2020/01/13 HTML / CSS
Pretty You London官网:英国拖鞋和睡衣品牌
2019/05/08 全球购物
在网络中有两台主机A和B,并通过路由器和其他交换设备连接起来,已经确认物理连接正确无误,怎么来测试这两台机器是否连通?如果不通,怎么来判断故障点?怎么排
2014/01/13 面试题
白酒业务员岗位职责
2013/12/27 职场文书
政府信息公开实施方案
2014/05/09 职场文书
护士优质服务演讲稿
2014/08/26 职场文书
2014年中职班主任工作总结
2014/12/16 职场文书
2015年电工工作总结
2015/04/10 职场文书
2016年元旦致辞
2015/08/01 职场文书
高一语文教学反思
2016/02/16 职场文书
2016年“我们的节日·中秋节”活动总结
2016/04/05 职场文书