Tensorflow之梯度裁剪的实现示例


Posted in Python onMarch 08, 2020

tensorflow中的梯度计算和更新

为了解决深度学习中常见的梯度消失(gradient explosion)和梯度爆炸(gradients vanishing)问题,tensorflow中所有的优化器tf.train.xxxOptimizer都有两个方法:

  1. compute_gradients
  2. apply_gradients

compute_gradients

对于compute_gradients方法,计算var_list中参数的梯度,使得loss变小。默认情况下,var_list为GraphKeys.TRAINABLE_VARIABLES中的所有参数。

compute_gradients方法返回由多个(gradients, variable)二元组组成的列表。

compute_gradients(
  loss,
  var_list=None,
  gate_gradients=GATE_OP,
  aggregation_method=None,
  colocate_gradients_with_ops=False,
  grad_loss=None
)

apply_gradients

对于apply_gradients方法,根据compute_gradients的返回结果对参数进行更新

apply_gradients(
  grads_and_vars,
  global_step=None,
  name=None
)

梯度裁剪(Gradient Clipping)

tensorflow中裁剪梯度的几种方式

方法一tf.clip_by_value

def clip_by_value(t, clip_value_min, clip_value_max,
         name=None):

其中,t为一个张量,clip_by_value返回一个与t的type相同、shape相同的张量,但是新tensor中的值被裁剪到了clip_value_min和clip_value_max之间。

方法二:tf.clip_by_global_norm

def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):

其中,t_list为A tuple or list of mixed Tensors, IndexedSlices, or None。clip_norm为clipping ratio,use_norm指定global_norm,如果use_norm为None,则按global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))计算global_norm。

最终,梯度的裁剪方式为

Tensorflow之梯度裁剪的实现示例

可知,如果clip_norm > global_norm, 则不对梯度进行裁剪,否则对梯度进行缩放。

scale = clip_norm * math_ops.minimum(
    1.0 / use_norm,
    constant_op.constant(1.0, dtype=use_norm.dtype) / clip_norm)

方法的返回值为裁剪后的梯度列表list_clipped和global_norm

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate)
gradients, v = zip(*optimizer.compute_gradients(loss))
gradients, _ = tf.clip_by_global_norm(gradients, grad_clip)
updates = optimizer.apply_gradients(zip(gradients, v),global_step=global_step)

方法三tf.clip_by_average_norm

def clip_by_average_norm(t, clip_norm, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式如下,

Tensorflow之梯度裁剪的实现示例

其中,avg_norm=l2norm_avg(t)

方法四:tf.clip_by_norm

def clip_by_norm(t, clip_norm, axes=None, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式为

Tensorflow之梯度裁剪的实现示例

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
grads = optimizer.compute_gradients(cost)
for i, (g, v) in enumerate(grads):
  if g is not None:
    grads[i] = (tf.clip_by_norm(g, 5), v) # clip gradients
train_op = optimizer.apply_gradients(grads)

注意到,clip_by_value、clib_by-avg_norm和clip_by_norm都是针对于单个张量的,而clip_by_global_norm可用于多个张量组成的列表。

到此这篇关于Tensorflow之梯度裁剪的实现示例的文章就介绍到这了,更多相关Tensorflow 梯度裁剪内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python将人民币转换大写的脚本代码
Feb 10 Python
python中argparse模块用法实例详解
Jun 03 Python
python实现字符串连接的三种方法及其效率、适用场景详解
Jan 13 Python
Python基于socket实现简单的即时通讯功能示例
Jan 16 Python
详解pyqt5 动画在QThread线程中无法运行问题
May 05 Python
Python3中函数参数传递方式实例详解
May 05 Python
python里dict变成list实例方法
Jun 26 Python
Python 进程操作之进程间通过队列共享数据,队列Queue简单示例
Oct 11 Python
python实现二分类的卡方分箱示例
Nov 22 Python
在pytorch中对非叶节点的变量计算梯度实例
Jan 10 Python
Python利用逻辑回归模型解决MNIST手写数字识别问题详解
Jan 14 Python
python实现密码验证合格程序的思路详解
Jun 01 Python
Django自定义全局403、404、500错误页面的示例代码
Mar 08 #Python
Django 自定义404 500等错误页面的实现
Mar 08 #Python
Python loguru日志库之高效输出控制台日志和日志记录
Mar 07 #Python
Centos7下源码安装Python3 及shell 脚本自动安装Python3的教程
Mar 07 #Python
Django接收照片储存文件的实例代码
Mar 07 #Python
Python实现对adb命令封装
Mar 06 #Python
对Python中 \r, \n, \r\n的彻底理解
Mar 06 #Python
You might like
PHP中其实也可以用方法链
2011/11/10 PHP
php实现图形显示Ip地址的代码及注释
2014/01/20 PHP
PHP面向对象教程之自定义类
2014/06/10 PHP
PHP实现根据密码长度显示安全条
2017/07/04 PHP
Javascript & DHTML 实例编程(教程)DOM基础和基本API
2007/06/02 Javascript
JavaScript实现拼音排序的方法
2012/11/20 Javascript
JavaScript 实现简单的倒计时弹窗DEMO附图
2014/03/05 Javascript
AngularJS中实现用户访问的身份认证和表单验证功能
2016/04/21 Javascript
jQuery轻松实现表格的隔行变色和点击行变色的实例代码
2016/05/09 Javascript
学习Javascript闭包(Closure)知识
2016/08/07 Javascript
Canvas实现动态的雪花效果
2017/02/13 Javascript
jQuery日期范围选择器附源码下载
2017/05/23 jQuery
JavaScript注册时密码强度校验代码
2017/06/30 Javascript
angular使用bootstrap方法手动启动的实例代码
2017/07/18 Javascript
JS闭包的几种常见形式实例详解
2017/09/16 Javascript
JS组件系列之Gojs组件 前端图形化插件之利器
2017/11/29 Javascript
Vue瀑布流插件的使用示例
2018/09/19 Javascript
详解vue中router-link标签所必备了解的属性
2019/04/15 Javascript
怎么理解wx.navigateTo的events参数使用详情
2020/05/18 Javascript
Python实现分段线性插值
2018/12/17 Python
python 利用浏览器 Cookie 模拟登录的用户访问知乎的方法
2019/07/11 Python
keras load model时出现Missing Layer错误的解决方式
2020/06/11 Python
什么是python的自省
2020/06/21 Python
Application Cache未缓存文件无法访问无法加载问题
2014/05/31 HTML / CSS
德国运动鞋网上商店:Afew Store
2018/01/05 全球购物
英国最大的海报商店:GB Posters
2018/03/20 全球购物
Ariat官网:美国马靴和服装品牌
2019/12/16 全球购物
音乐教学反思
2014/02/02 职场文书
公务员保密承诺书
2014/03/27 职场文书
关于青春的演讲稿三分钟
2014/08/22 职场文书
邓小平理论心得体会
2014/09/09 职场文书
办公室主任四风问题对照检查材料思想汇报
2014/09/28 职场文书
2015财务年终工作总结范文
2015/05/22 职场文书
PHP判断是否是json字符串
2021/04/01 PHP
在pycharm中无法import所安装的库解决方案
2021/05/31 Python
pytorch finetuning 自己的图片进行训练操作
2021/06/05 Python