Tensorflow之梯度裁剪的实现示例


Posted in Python onMarch 08, 2020

tensorflow中的梯度计算和更新

为了解决深度学习中常见的梯度消失(gradient explosion)和梯度爆炸(gradients vanishing)问题,tensorflow中所有的优化器tf.train.xxxOptimizer都有两个方法:

  1. compute_gradients
  2. apply_gradients

compute_gradients

对于compute_gradients方法,计算var_list中参数的梯度,使得loss变小。默认情况下,var_list为GraphKeys.TRAINABLE_VARIABLES中的所有参数。

compute_gradients方法返回由多个(gradients, variable)二元组组成的列表。

compute_gradients(
  loss,
  var_list=None,
  gate_gradients=GATE_OP,
  aggregation_method=None,
  colocate_gradients_with_ops=False,
  grad_loss=None
)

apply_gradients

对于apply_gradients方法,根据compute_gradients的返回结果对参数进行更新

apply_gradients(
  grads_and_vars,
  global_step=None,
  name=None
)

梯度裁剪(Gradient Clipping)

tensorflow中裁剪梯度的几种方式

方法一tf.clip_by_value

def clip_by_value(t, clip_value_min, clip_value_max,
         name=None):

其中,t为一个张量,clip_by_value返回一个与t的type相同、shape相同的张量,但是新tensor中的值被裁剪到了clip_value_min和clip_value_max之间。

方法二:tf.clip_by_global_norm

def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):

其中,t_list为A tuple or list of mixed Tensors, IndexedSlices, or None。clip_norm为clipping ratio,use_norm指定global_norm,如果use_norm为None,则按global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))计算global_norm。

最终,梯度的裁剪方式为

Tensorflow之梯度裁剪的实现示例

可知,如果clip_norm > global_norm, 则不对梯度进行裁剪,否则对梯度进行缩放。

scale = clip_norm * math_ops.minimum(
    1.0 / use_norm,
    constant_op.constant(1.0, dtype=use_norm.dtype) / clip_norm)

方法的返回值为裁剪后的梯度列表list_clipped和global_norm

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate)
gradients, v = zip(*optimizer.compute_gradients(loss))
gradients, _ = tf.clip_by_global_norm(gradients, grad_clip)
updates = optimizer.apply_gradients(zip(gradients, v),global_step=global_step)

方法三tf.clip_by_average_norm

def clip_by_average_norm(t, clip_norm, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式如下,

Tensorflow之梯度裁剪的实现示例

其中,avg_norm=l2norm_avg(t)

方法四:tf.clip_by_norm

def clip_by_norm(t, clip_norm, axes=None, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式为

Tensorflow之梯度裁剪的实现示例

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
grads = optimizer.compute_gradients(cost)
for i, (g, v) in enumerate(grads):
  if g is not None:
    grads[i] = (tf.clip_by_norm(g, 5), v) # clip gradients
train_op = optimizer.apply_gradients(grads)

注意到,clip_by_value、clib_by-avg_norm和clip_by_norm都是针对于单个张量的,而clip_by_global_norm可用于多个张量组成的列表。

到此这篇关于Tensorflow之梯度裁剪的实现示例的文章就介绍到这了,更多相关Tensorflow 梯度裁剪内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
采用python实现简单QQ单用户机器人的方法
Jul 03 Python
Python画图学习入门教程
Jul 01 Python
Ubuntu 下 vim 搭建python 环境 配置
Jun 12 Python
python3解析库pyquery的深入讲解
Jun 26 Python
实例分析python3实现并发访问水平切分表
Sep 29 Python
Python将文字转成语音并读出来的实例详解
Jul 15 Python
一行python实现树形结构的方法
Aug 09 Python
python如何使用socketserver模块实现并发聊天
Dec 14 Python
Python基于os.environ从windows获取环境变量
Jun 09 Python
pandas.DataFrame.drop_duplicates 用法介绍
Jul 06 Python
python 从list中随机取值的方法
Nov 16 Python
python通过cython加密代码
Dec 11 Python
Django自定义全局403、404、500错误页面的示例代码
Mar 08 #Python
Django 自定义404 500等错误页面的实现
Mar 08 #Python
Python loguru日志库之高效输出控制台日志和日志记录
Mar 07 #Python
Centos7下源码安装Python3 及shell 脚本自动安装Python3的教程
Mar 07 #Python
Django接收照片储存文件的实例代码
Mar 07 #Python
Python实现对adb命令封装
Mar 06 #Python
对Python中 \r, \n, \r\n的彻底理解
Mar 06 #Python
You might like
php visitFile()遍历指定文件夹函数
2010/08/21 PHP
应用开发中涉及到的css和php笔记分享
2011/08/02 PHP
php使用scandir()函数扫描指定目录下所有文件示例
2019/06/08 PHP
php常用字符串查找函数strstr()与strpos()实例分析
2019/06/21 PHP
firefox浏览器下javascript 拖动层效果与原理分析代码
2007/12/04 Javascript
jcarousellite.js 基于Jquery的图片无缝滚动插件
2010/12/30 Javascript
document.createElement()用法
2013/03/13 Javascript
jquery单选框radio绑定click事件实现方法
2015/01/14 Javascript
Node.js中的缓冲与流模块详细介绍
2015/02/11 Javascript
基于JS实现无缝滚动思路及代码分享
2016/06/07 Javascript
BootStrap fileinput.js文件上传组件实例代码
2017/02/20 Javascript
Vue keepAlive 数据缓存工具实现返回上一个页面浏览的位置
2019/05/10 Javascript
微信小程序如何调用新闻接口实现列表循环
2019/07/02 Javascript
BootStrap前端框架使用方法详解
2020/02/26 Javascript
从Node.js事件触发器到Vue自定义事件的深入讲解
2020/06/26 Javascript
vue监听dom大小改变案例
2020/07/29 Javascript
[50:02]完美世界DOTA2联赛循环赛 Magma vs IO BO2第一场 11.01
2020/11/02 DOTA
Python的Tornado框架实现异步非阻塞访问数据库的示例
2016/06/30 Python
Python 通配符删除文件的实例
2018/04/24 Python
对python3 sort sorted 函数的应用详解
2019/06/27 Python
python3+opencv 使用灰度直方图来判断图片的亮暗操作
2020/06/02 Python
Flask中sqlalchemy模块的实例用法
2020/08/02 Python
JupyterNotebook 输出窗口的显示效果调整实现
2020/09/22 Python
探究 canvas 绘图中撤销(undo)功能的实现方式详解
2018/05/17 HTML / CSS
法国女性内衣购物网站:Glamuse
2019/05/13 全球购物
英国计算机商店:Technextday
2019/12/28 全球购物
Perfume’s Club中文官网:西班牙美妆在线零售品牌
2020/08/24 全球购物
煤矿班组长岗位职责
2013/12/29 职场文书
管事部库房保管员岗位职责
2014/02/21 职场文书
优秀求职信
2014/05/29 职场文书
银行求职信模板
2015/03/20 职场文书
同学聚会通知短信
2015/04/20 职场文书
Python爬虫基础之爬虫的分类知识总结
2021/05/13 Python
python实现A*寻路算法
2021/06/13 Python
mysql如何能有效防止删库跑路
2021/10/05 MySQL
Nginx静态压缩和代码压缩提高访问速度详解
2022/05/30 Servers