Tensorflow之梯度裁剪的实现示例


Posted in Python onMarch 08, 2020

tensorflow中的梯度计算和更新

为了解决深度学习中常见的梯度消失(gradient explosion)和梯度爆炸(gradients vanishing)问题,tensorflow中所有的优化器tf.train.xxxOptimizer都有两个方法:

  1. compute_gradients
  2. apply_gradients

compute_gradients

对于compute_gradients方法,计算var_list中参数的梯度,使得loss变小。默认情况下,var_list为GraphKeys.TRAINABLE_VARIABLES中的所有参数。

compute_gradients方法返回由多个(gradients, variable)二元组组成的列表。

compute_gradients(
  loss,
  var_list=None,
  gate_gradients=GATE_OP,
  aggregation_method=None,
  colocate_gradients_with_ops=False,
  grad_loss=None
)

apply_gradients

对于apply_gradients方法,根据compute_gradients的返回结果对参数进行更新

apply_gradients(
  grads_and_vars,
  global_step=None,
  name=None
)

梯度裁剪(Gradient Clipping)

tensorflow中裁剪梯度的几种方式

方法一tf.clip_by_value

def clip_by_value(t, clip_value_min, clip_value_max,
         name=None):

其中,t为一个张量,clip_by_value返回一个与t的type相同、shape相同的张量,但是新tensor中的值被裁剪到了clip_value_min和clip_value_max之间。

方法二:tf.clip_by_global_norm

def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):

其中,t_list为A tuple or list of mixed Tensors, IndexedSlices, or None。clip_norm为clipping ratio,use_norm指定global_norm,如果use_norm为None,则按global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))计算global_norm。

最终,梯度的裁剪方式为

Tensorflow之梯度裁剪的实现示例

可知,如果clip_norm > global_norm, 则不对梯度进行裁剪,否则对梯度进行缩放。

scale = clip_norm * math_ops.minimum(
    1.0 / use_norm,
    constant_op.constant(1.0, dtype=use_norm.dtype) / clip_norm)

方法的返回值为裁剪后的梯度列表list_clipped和global_norm

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate)
gradients, v = zip(*optimizer.compute_gradients(loss))
gradients, _ = tf.clip_by_global_norm(gradients, grad_clip)
updates = optimizer.apply_gradients(zip(gradients, v),global_step=global_step)

方法三tf.clip_by_average_norm

def clip_by_average_norm(t, clip_norm, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式如下,

Tensorflow之梯度裁剪的实现示例

其中,avg_norm=l2norm_avg(t)

方法四:tf.clip_by_norm

def clip_by_norm(t, clip_norm, axes=None, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式为

Tensorflow之梯度裁剪的实现示例

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
grads = optimizer.compute_gradients(cost)
for i, (g, v) in enumerate(grads):
  if g is not None:
    grads[i] = (tf.clip_by_norm(g, 5), v) # clip gradients
train_op = optimizer.apply_gradients(grads)

注意到,clip_by_value、clib_by-avg_norm和clip_by_norm都是针对于单个张量的,而clip_by_global_norm可用于多个张量组成的列表。

到此这篇关于Tensorflow之梯度裁剪的实现示例的文章就介绍到这了,更多相关Tensorflow 梯度裁剪内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Linux下Python获取IP地址的代码
Nov 30 Python
python中input()与raw_input()的区别分析
Feb 27 Python
python3使用urllib模块制作网络爬虫
Apr 08 Python
selenium获取当前页面的url、源码、title的方法
Jun 12 Python
python开发之anaconda以及win7下安装gensim的方法
Jul 05 Python
自定义django admin model表单提交的例子
Aug 23 Python
python解析yaml文件过程详解
Aug 30 Python
PyTorch实现AlexNet示例
Jan 14 Python
浅谈Python中range与Numpy中arange的比较
Mar 11 Python
如何配置关联Python 解释器 Anaconda的教程(图解)
Apr 30 Python
Django contrib auth authenticate函数源码解析
Nov 12 Python
Python内置包对JSON文件数据进行编码和解码
Apr 12 Python
Django自定义全局403、404、500错误页面的示例代码
Mar 08 #Python
Django 自定义404 500等错误页面的实现
Mar 08 #Python
Python loguru日志库之高效输出控制台日志和日志记录
Mar 07 #Python
Centos7下源码安装Python3 及shell 脚本自动安装Python3的教程
Mar 07 #Python
Django接收照片储存文件的实例代码
Mar 07 #Python
Python实现对adb命令封装
Mar 06 #Python
对Python中 \r, \n, \r\n的彻底理解
Mar 06 #Python
You might like
ASP知识讲座四
2006/10/09 PHP
php相当简单的分页类
2008/10/02 PHP
$_GET['goods_id']+0 的使用详解
2013/06/06 PHP
PHP高并发和大流量解决方案整理
2019/12/24 PHP
javaScript同意等待代码实现心得
2011/01/01 Javascript
通过jquery的$.getJSON做一个跨域ajax请求试验
2011/05/03 Javascript
jQuery 过滤not()与filter()实例代码
2012/05/10 Javascript
jQuery中:has选择器用法实例
2014/12/30 Javascript
jQuery实现瀑布流布局详解(PC和移动端)
2020/09/01 Javascript
教你如何在Node.js中使用jQuery
2016/08/28 Javascript
获取当前月(季度/年)的最后一天(set相关操作及应用)
2016/12/27 Javascript
JS实现给对象动态添加属性的方法
2017/01/05 Javascript
jquery拼接ajax 的json和字符串拼接的方法
2017/03/11 Javascript
详解在express站点中使用ejs模板引擎
2017/09/21 Javascript
nodejs实现的简单web服务器功能示例
2018/03/15 NodeJs
使用jQuery动态设置单选框的选中效果
2018/12/06 jQuery
解析原来浏览器原生支持JS Base64编码解码
2019/08/12 Javascript
layui(1.0.9)文件上传upload,前后端的实例代码
2019/09/26 Javascript
jQuery实现中奖播报功能(让文本滚动起来) 简单设置数值即可
2020/03/20 jQuery
[03:07]完美世界DOTA2联赛PWL DAY10 决赛集锦
2020/11/11 DOTA
Windows下Python使用Pandas模块操作Excel文件的教程
2016/05/31 Python
Python基于OpenCV实现视频的人脸检测
2018/01/23 Python
Python PyCharm如何进行断点调试
2019/07/05 Python
Django 后台获取文件列表 InMemoryUploadedFile的例子
2019/08/07 Python
python multiprocessing多进程变量共享与加锁的实现
2019/10/02 Python
Jupyter加载文件的实现方法
2020/04/14 Python
Python 中 sorted 如何自定义比较逻辑
2021/02/02 Python
日本最新流行服饰网购:Nissen
2016/07/24 全球购物
严选全球尖货,立足香港:Bonpont宝盆
2018/07/24 全球购物
乌克兰在线商店的价格比较:Price.ua
2019/07/26 全球购物
M.M.LaFleur官网:美国职业女装品牌
2020/10/27 全球购物
2015年党支部公开承诺书
2015/01/22 职场文书
音乐教师求职信范文
2015/03/20 职场文书
加薪申请报告范本
2015/05/15 职场文书
详解Vue的列表渲染
2021/11/20 Vue.js
使用python绘制分组对比柱状图
2022/04/21 Python