Tensorflow之梯度裁剪的实现示例


Posted in Python onMarch 08, 2020

tensorflow中的梯度计算和更新

为了解决深度学习中常见的梯度消失(gradient explosion)和梯度爆炸(gradients vanishing)问题,tensorflow中所有的优化器tf.train.xxxOptimizer都有两个方法:

  1. compute_gradients
  2. apply_gradients

compute_gradients

对于compute_gradients方法,计算var_list中参数的梯度,使得loss变小。默认情况下,var_list为GraphKeys.TRAINABLE_VARIABLES中的所有参数。

compute_gradients方法返回由多个(gradients, variable)二元组组成的列表。

compute_gradients(
  loss,
  var_list=None,
  gate_gradients=GATE_OP,
  aggregation_method=None,
  colocate_gradients_with_ops=False,
  grad_loss=None
)

apply_gradients

对于apply_gradients方法,根据compute_gradients的返回结果对参数进行更新

apply_gradients(
  grads_and_vars,
  global_step=None,
  name=None
)

梯度裁剪(Gradient Clipping)

tensorflow中裁剪梯度的几种方式

方法一tf.clip_by_value

def clip_by_value(t, clip_value_min, clip_value_max,
         name=None):

其中,t为一个张量,clip_by_value返回一个与t的type相同、shape相同的张量,但是新tensor中的值被裁剪到了clip_value_min和clip_value_max之间。

方法二:tf.clip_by_global_norm

def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):

其中,t_list为A tuple or list of mixed Tensors, IndexedSlices, or None。clip_norm为clipping ratio,use_norm指定global_norm,如果use_norm为None,则按global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))计算global_norm。

最终,梯度的裁剪方式为

Tensorflow之梯度裁剪的实现示例

可知,如果clip_norm > global_norm, 则不对梯度进行裁剪,否则对梯度进行缩放。

scale = clip_norm * math_ops.minimum(
    1.0 / use_norm,
    constant_op.constant(1.0, dtype=use_norm.dtype) / clip_norm)

方法的返回值为裁剪后的梯度列表list_clipped和global_norm

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate)
gradients, v = zip(*optimizer.compute_gradients(loss))
gradients, _ = tf.clip_by_global_norm(gradients, grad_clip)
updates = optimizer.apply_gradients(zip(gradients, v),global_step=global_step)

方法三tf.clip_by_average_norm

def clip_by_average_norm(t, clip_norm, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式如下,

Tensorflow之梯度裁剪的实现示例

其中,avg_norm=l2norm_avg(t)

方法四:tf.clip_by_norm

def clip_by_norm(t, clip_norm, axes=None, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式为

Tensorflow之梯度裁剪的实现示例

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
grads = optimizer.compute_gradients(cost)
for i, (g, v) in enumerate(grads):
  if g is not None:
    grads[i] = (tf.clip_by_norm(g, 5), v) # clip gradients
train_op = optimizer.apply_gradients(grads)

注意到,clip_by_value、clib_by-avg_norm和clip_by_norm都是针对于单个张量的,而clip_by_global_norm可用于多个张量组成的列表。

到此这篇关于Tensorflow之梯度裁剪的实现示例的文章就介绍到这了,更多相关Tensorflow 梯度裁剪内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python中的对象,方法,类,实例,函数用法分析
Jan 15 Python
Python单元测试框架unittest简明使用实例
Apr 13 Python
Python网络编程详解
Oct 31 Python
Python设计模式之代理模式简单示例
Jan 09 Python
给你一面国旗 教你用python画中国国旗
Sep 24 Python
基于Python批量生成指定尺寸缩略图代码实例
Nov 20 Python
python操作cfg配置文件方式
Dec 22 Python
mac 上配置Pycharm连接远程服务器并实现使用远程服务器Python解释器的方法
Mar 19 Python
Python3爬虫中Selenium的用法详解
Jul 10 Python
Python列表推导式实现代码实例
Sep 09 Python
使用python操作lmdb对数据读取的实例
Dec 11 Python
Python操作PostgreSql数据库的方法(基本的增删改查)
Dec 29 Python
Django自定义全局403、404、500错误页面的示例代码
Mar 08 #Python
Django 自定义404 500等错误页面的实现
Mar 08 #Python
Python loguru日志库之高效输出控制台日志和日志记录
Mar 07 #Python
Centos7下源码安装Python3 及shell 脚本自动安装Python3的教程
Mar 07 #Python
Django接收照片储存文件的实例代码
Mar 07 #Python
Python实现对adb命令封装
Mar 06 #Python
对Python中 \r, \n, \r\n的彻底理解
Mar 06 #Python
You might like
评分9.0以上的动画电影,剧情除了经典还很燃
2020/03/04 日漫
php 静态页面中显示动态内容
2009/08/14 PHP
PHP语言中global和$GLOBALS[]的分析 之二
2012/02/02 PHP
Apache中php.ini的设置方法
2013/02/28 PHP
基于ubuntu下nginx+php+mysql安装配置的具体操作步骤
2013/04/28 PHP
php查找任何页面上的所有链接的方法
2013/12/03 PHP
PHP实现阳历到农历转换的类实例
2015/03/07 PHP
php使用crypt()函数进行加密
2017/06/08 PHP
JQuery 构建客户/服务分离的链接模型中Table分页代码效率初探
2010/01/22 Javascript
基于jquery的跟随屏幕滚动代码
2012/07/24 Javascript
简单漂亮的js弹窗可自由拖拽且兼容大部分浏览器
2013/10/22 Javascript
JS版元素周期表实现方法
2015/08/05 Javascript
在javascript中创建对象的各种模式解析
2016/05/16 Javascript
jQuery插件HighCharts实现的2D堆条状图效果示例【附demo源码下载】
2017/03/14 Javascript
详解React中的组件通信问题
2017/07/31 Javascript
JS实现的加减乘除四则运算计算器示例
2017/08/09 Javascript
JS实现匀加速与匀减速运动的方法示例
2017/09/04 Javascript
关于Ajax的原理以及代码封装详解
2017/09/08 Javascript
详解vue-cli项目中用json-sever搭建mock服务器
2017/11/02 Javascript
通过图带你深入了解vue的响应式原理
2019/06/21 Javascript
python3实现磁盘空间监控
2018/06/21 Python
pymysql模块的使用(增删改查)详解
2019/09/09 Python
Python基于Socket实现简单聊天室
2020/02/17 Python
python GUI库图形界面开发之PyQt5图片显示控件QPixmap详细使用方法与实例
2020/02/27 Python
python seaborn heatmap可视化相关性矩阵实例
2020/06/03 Python
Python实现加密接口测试方法步骤详解
2020/06/05 Python
Python如何使用ElementTree解析xml
2020/10/12 Python
Python爬虫模拟登陆哔哩哔哩(bilibili)并突破点选验证码功能
2020/12/21 Python
html5教程制作简单画板代码分享
2013/12/04 HTML / CSS
深入剖析HTML5 内联框架iFrame
2016/05/04 HTML / CSS
奥地利度假券的专家:we-are.travel
2019/04/10 全球购物
文员个人的求职信范文
2013/09/26 职场文书
奖学金自我鉴定范文
2013/10/03 职场文书
银行员工职业规划范文
2014/01/21 职场文书
党小组推荐意见
2015/06/02 职场文书
使用pytorch实现线性回归
2021/04/11 Python