Tensorflow之梯度裁剪的实现示例


Posted in Python onMarch 08, 2020

tensorflow中的梯度计算和更新

为了解决深度学习中常见的梯度消失(gradient explosion)和梯度爆炸(gradients vanishing)问题,tensorflow中所有的优化器tf.train.xxxOptimizer都有两个方法:

  1. compute_gradients
  2. apply_gradients

compute_gradients

对于compute_gradients方法,计算var_list中参数的梯度,使得loss变小。默认情况下,var_list为GraphKeys.TRAINABLE_VARIABLES中的所有参数。

compute_gradients方法返回由多个(gradients, variable)二元组组成的列表。

compute_gradients(
  loss,
  var_list=None,
  gate_gradients=GATE_OP,
  aggregation_method=None,
  colocate_gradients_with_ops=False,
  grad_loss=None
)

apply_gradients

对于apply_gradients方法,根据compute_gradients的返回结果对参数进行更新

apply_gradients(
  grads_and_vars,
  global_step=None,
  name=None
)

梯度裁剪(Gradient Clipping)

tensorflow中裁剪梯度的几种方式

方法一tf.clip_by_value

def clip_by_value(t, clip_value_min, clip_value_max,
         name=None):

其中,t为一个张量,clip_by_value返回一个与t的type相同、shape相同的张量,但是新tensor中的值被裁剪到了clip_value_min和clip_value_max之间。

方法二:tf.clip_by_global_norm

def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):

其中,t_list为A tuple or list of mixed Tensors, IndexedSlices, or None。clip_norm为clipping ratio,use_norm指定global_norm,如果use_norm为None,则按global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))计算global_norm。

最终,梯度的裁剪方式为

Tensorflow之梯度裁剪的实现示例

可知,如果clip_norm > global_norm, 则不对梯度进行裁剪,否则对梯度进行缩放。

scale = clip_norm * math_ops.minimum(
    1.0 / use_norm,
    constant_op.constant(1.0, dtype=use_norm.dtype) / clip_norm)

方法的返回值为裁剪后的梯度列表list_clipped和global_norm

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate)
gradients, v = zip(*optimizer.compute_gradients(loss))
gradients, _ = tf.clip_by_global_norm(gradients, grad_clip)
updates = optimizer.apply_gradients(zip(gradients, v),global_step=global_step)

方法三tf.clip_by_average_norm

def clip_by_average_norm(t, clip_norm, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式如下,

Tensorflow之梯度裁剪的实现示例

其中,avg_norm=l2norm_avg(t)

方法四:tf.clip_by_norm

def clip_by_norm(t, clip_norm, axes=None, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式为

Tensorflow之梯度裁剪的实现示例

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
grads = optimizer.compute_gradients(cost)
for i, (g, v) in enumerate(grads):
  if g is not None:
    grads[i] = (tf.clip_by_norm(g, 5), v) # clip gradients
train_op = optimizer.apply_gradients(grads)

注意到,clip_by_value、clib_by-avg_norm和clip_by_norm都是针对于单个张量的,而clip_by_global_norm可用于多个张量组成的列表。

到此这篇关于Tensorflow之梯度裁剪的实现示例的文章就介绍到这了,更多相关Tensorflow 梯度裁剪内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python中运行并行任务技巧
Feb 26 Python
python基础_文件操作实现全文或单行替换的方法
Sep 04 Python
Python数据结构与算法之列表(链表,linked list)简单实现
Oct 30 Python
python机器学习之随机森林(七)
Mar 26 Python
对python判断是否回文数的实例详解
Feb 08 Python
使用python绘制二元函数图像的实例
Feb 12 Python
解决Python计算矩阵乘向量,矩阵乘实数的一些小错误
Aug 26 Python
在Python中使用turtle绘制多个同心圆示例
Nov 23 Python
Scrapy框架基本命令与settings.py设置
Feb 06 Python
在PyTorch中使用标签平滑正则化的问题
Apr 03 Python
如何通过Python3和ssl实现加密通信功能
May 09 Python
python爬取豆瓣电影TOP250数据
May 23 Python
Django自定义全局403、404、500错误页面的示例代码
Mar 08 #Python
Django 自定义404 500等错误页面的实现
Mar 08 #Python
Python loguru日志库之高效输出控制台日志和日志记录
Mar 07 #Python
Centos7下源码安装Python3 及shell 脚本自动安装Python3的教程
Mar 07 #Python
Django接收照片储存文件的实例代码
Mar 07 #Python
Python实现对adb命令封装
Mar 06 #Python
对Python中 \r, \n, \r\n的彻底理解
Mar 06 #Python
You might like
模板引擎smarty工作原理以及使用示例
2014/05/25 PHP
jQuery EasyUI NumberBox(数字框)的用法
2010/07/08 Javascript
JavaScript格式化数字的函数代码
2010/11/30 Javascript
基于jquery自定义的漂亮单选按钮RadioButton
2013/11/19 Javascript
javascript 控制input只允许输入的各种指定内容
2014/06/19 Javascript
javascript手风琴下拉菜单实现代码
2015/11/12 Javascript
全面介绍javascript实用技巧及单竖杠
2016/07/18 Javascript
前端弹出对话框 js实现ajax交互
2016/09/09 Javascript
简单谈谈JS数组中的indexOf方法
2016/10/13 Javascript
JS实现图片上传预览功能
2016/11/21 Javascript
JavaScript实现翻页功能(附效果图)
2017/02/16 Javascript
微信小程序地图(map)组件点击(tap)获取经纬度的方法
2019/01/10 Javascript
JS实现的贪吃蛇游戏案例详解
2019/05/01 Javascript
初学node.js中实现删除用户路由
2019/05/27 Javascript
vue App.vue中的公共组件改变值触发其他组件或.vue页面监听
2019/05/31 Javascript
bat和python批量重命名文件的实现代码
2016/05/19 Python
python实现的MySQL增删改查操作实例小结
2018/12/19 Python
详解10个可以快速用Python进行数据分析的小技巧
2019/06/24 Python
Python使用APScheduler实现定时任务过程解析
2019/09/11 Python
Python中os模块功能与用法详解
2020/02/26 Python
python轮询机制控制led实例
2020/05/03 Python
使用CSS3的::selection改变选中文本颜色的方法
2015/09/29 HTML / CSS
HTML5 UTF-8 中文乱码的解决方法
2013/11/18 HTML / CSS
html5唤醒APP小记
2019/03/27 HTML / CSS
英国皇家造币厂:The Royal Mint
2018/10/05 全球购物
奥地利婴儿用品和玩具购物网站:baby-markt.at
2020/01/26 全球购物
公休请假条
2014/04/11 职场文书
合作经营协议书范本
2014/04/17 职场文书
汽车检测与维修专业求职信
2014/07/04 职场文书
教师节倡议书
2014/08/30 职场文书
合理化建议书
2015/02/04 职场文书
医德医风学习心得体会
2016/01/25 职场文书
小学体育跳绳课教学反思
2016/02/16 职场文书
《曾国藩家书》读后感——读家书,立家风
2019/08/21 职场文书
用python画城市轮播地图
2021/05/28 Python
【海涛教你打DOTA】剑圣第一人称视角解说
2022/04/01 DOTA