Tensorflow之梯度裁剪的实现示例


Posted in Python onMarch 08, 2020

tensorflow中的梯度计算和更新

为了解决深度学习中常见的梯度消失(gradient explosion)和梯度爆炸(gradients vanishing)问题,tensorflow中所有的优化器tf.train.xxxOptimizer都有两个方法:

  1. compute_gradients
  2. apply_gradients

compute_gradients

对于compute_gradients方法,计算var_list中参数的梯度,使得loss变小。默认情况下,var_list为GraphKeys.TRAINABLE_VARIABLES中的所有参数。

compute_gradients方法返回由多个(gradients, variable)二元组组成的列表。

compute_gradients(
  loss,
  var_list=None,
  gate_gradients=GATE_OP,
  aggregation_method=None,
  colocate_gradients_with_ops=False,
  grad_loss=None
)

apply_gradients

对于apply_gradients方法,根据compute_gradients的返回结果对参数进行更新

apply_gradients(
  grads_and_vars,
  global_step=None,
  name=None
)

梯度裁剪(Gradient Clipping)

tensorflow中裁剪梯度的几种方式

方法一tf.clip_by_value

def clip_by_value(t, clip_value_min, clip_value_max,
         name=None):

其中,t为一个张量,clip_by_value返回一个与t的type相同、shape相同的张量,但是新tensor中的值被裁剪到了clip_value_min和clip_value_max之间。

方法二:tf.clip_by_global_norm

def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):

其中,t_list为A tuple or list of mixed Tensors, IndexedSlices, or None。clip_norm为clipping ratio,use_norm指定global_norm,如果use_norm为None,则按global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))计算global_norm。

最终,梯度的裁剪方式为

Tensorflow之梯度裁剪的实现示例

可知,如果clip_norm > global_norm, 则不对梯度进行裁剪,否则对梯度进行缩放。

scale = clip_norm * math_ops.minimum(
    1.0 / use_norm,
    constant_op.constant(1.0, dtype=use_norm.dtype) / clip_norm)

方法的返回值为裁剪后的梯度列表list_clipped和global_norm

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate)
gradients, v = zip(*optimizer.compute_gradients(loss))
gradients, _ = tf.clip_by_global_norm(gradients, grad_clip)
updates = optimizer.apply_gradients(zip(gradients, v),global_step=global_step)

方法三tf.clip_by_average_norm

def clip_by_average_norm(t, clip_norm, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式如下,

Tensorflow之梯度裁剪的实现示例

其中,avg_norm=l2norm_avg(t)

方法四:tf.clip_by_norm

def clip_by_norm(t, clip_norm, axes=None, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式为

Tensorflow之梯度裁剪的实现示例

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
grads = optimizer.compute_gradients(cost)
for i, (g, v) in enumerate(grads):
  if g is not None:
    grads[i] = (tf.clip_by_norm(g, 5), v) # clip gradients
train_op = optimizer.apply_gradients(grads)

注意到,clip_by_value、clib_by-avg_norm和clip_by_norm都是针对于单个张量的,而clip_by_global_norm可用于多个张量组成的列表。

到此这篇关于Tensorflow之梯度裁剪的实现示例的文章就介绍到这了,更多相关Tensorflow 梯度裁剪内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python深入学习之特殊方法与多范式
Aug 31 Python
Python进程通信之匿名管道实例讲解
Apr 11 Python
Windows下Python的Django框架环境部署及应用编写入门
Mar 10 Python
Python Nose框架编写测试用例方法
Oct 26 Python
详解Python里使用正则表达式的ASCII模式
Nov 02 Python
python mac下安装虚拟环境的图文教程
Apr 12 Python
使用Python opencv实现视频与图片的相互转换
Jul 08 Python
Django认证系统实现的web页面实现代码
Aug 12 Python
详解基于python的多张不同宽高图片拼接成大图
Sep 26 Python
opencv调整图像亮度对比度的示例代码
Sep 27 Python
利用Python发送邮件或发带附件的邮件
Nov 12 Python
tensorflow中的数据类型dtype用法说明
May 26 Python
Django自定义全局403、404、500错误页面的示例代码
Mar 08 #Python
Django 自定义404 500等错误页面的实现
Mar 08 #Python
Python loguru日志库之高效输出控制台日志和日志记录
Mar 07 #Python
Centos7下源码安装Python3 及shell 脚本自动安装Python3的教程
Mar 07 #Python
Django接收照片储存文件的实例代码
Mar 07 #Python
Python实现对adb命令封装
Mar 06 #Python
对Python中 \r, \n, \r\n的彻底理解
Mar 06 #Python
You might like
77A一级收信机修理记
2021/03/02 无线电
利用PHP实现与ASP Banner组件相似的类
2006/10/09 PHP
基于PHP实现商品成交时发送短信功能
2016/05/11 PHP
CI框架中$this->load->library()用法分析
2016/05/18 PHP
PHP从尾到头打印链表实例讲解
2018/09/27 PHP
PHP与Web页面的交互示例详解一
2020/08/04 PHP
iframe 父窗口和子窗口相互的调用方法集锦
2010/12/15 Javascript
js获取系统的根路径实现介绍
2013/09/08 Javascript
jquery提示效果实例分析
2014/11/25 Javascript
JavaScript中Cookies的相关使用教程
2015/06/04 Javascript
快速学习jQuery插件 Cookie插件使用方法
2015/12/01 Javascript
JS中使用DOM来控制HTML元素
2016/07/31 Javascript
Angular2自定义分页组件
2017/04/19 Javascript
vue自动化表单实例分析
2018/05/06 Javascript
Vue props用法详解(小结)
2018/07/03 Javascript
5分钟快速掌握JS中var、let和const的异同
2018/09/19 Javascript
IE浏览器下JS脚本提交表单后,不能自动提示问题解决方法
2019/06/04 Javascript
Python中str is not callable问题详解及解决办法
2017/02/10 Python
详解python OpenCV学习笔记之直方图均衡化
2018/02/08 Python
sublime python3 输入换行不结束的方法
2018/04/19 Python
Sanic框架安装与简单入门示例
2018/07/16 Python
Python中logging实例讲解
2019/01/17 Python
Python中os模块功能与用法详解
2020/02/26 Python
html5的画布canvas——画出弧线、旋转的图形实例代码+效果图
2013/06/09 HTML / CSS
Betsey Johnson官网:妖娆可爱的连衣裙及鞋子、手袋和配件
2016/12/30 全球购物
全球最大的生存食品、水和装备专用在线市场:BePrepared.com
2020/01/02 全球购物
应届大学生求职的自我评价
2013/11/17 职场文书
小学英语课后反思
2014/04/26 职场文书
幼儿教师求职信
2014/05/24 职场文书
高等教育学专业自荐书
2014/06/17 职场文书
党的群众路线查摆剖析材料
2014/10/10 职场文书
2015年员工试用期工作总结
2015/05/28 职场文书
黑暗中的舞者观后感
2015/06/18 职场文书
生鲜超市—未来中国最具有潜力零售业态
2019/08/02 职场文书
SQLServer之常用函数总结详解
2021/08/30 SQL Server
Shell脚本一键安装Nginx服务自定义Nginx版本
2022/03/20 Servers