Tensorflow之梯度裁剪的实现示例


Posted in Python onMarch 08, 2020

tensorflow中的梯度计算和更新

为了解决深度学习中常见的梯度消失(gradient explosion)和梯度爆炸(gradients vanishing)问题,tensorflow中所有的优化器tf.train.xxxOptimizer都有两个方法:

  1. compute_gradients
  2. apply_gradients

compute_gradients

对于compute_gradients方法,计算var_list中参数的梯度,使得loss变小。默认情况下,var_list为GraphKeys.TRAINABLE_VARIABLES中的所有参数。

compute_gradients方法返回由多个(gradients, variable)二元组组成的列表。

compute_gradients(
  loss,
  var_list=None,
  gate_gradients=GATE_OP,
  aggregation_method=None,
  colocate_gradients_with_ops=False,
  grad_loss=None
)

apply_gradients

对于apply_gradients方法,根据compute_gradients的返回结果对参数进行更新

apply_gradients(
  grads_and_vars,
  global_step=None,
  name=None
)

梯度裁剪(Gradient Clipping)

tensorflow中裁剪梯度的几种方式

方法一tf.clip_by_value

def clip_by_value(t, clip_value_min, clip_value_max,
         name=None):

其中,t为一个张量,clip_by_value返回一个与t的type相同、shape相同的张量,但是新tensor中的值被裁剪到了clip_value_min和clip_value_max之间。

方法二:tf.clip_by_global_norm

def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):

其中,t_list为A tuple or list of mixed Tensors, IndexedSlices, or None。clip_norm为clipping ratio,use_norm指定global_norm,如果use_norm为None,则按global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))计算global_norm。

最终,梯度的裁剪方式为

Tensorflow之梯度裁剪的实现示例

可知,如果clip_norm > global_norm, 则不对梯度进行裁剪,否则对梯度进行缩放。

scale = clip_norm * math_ops.minimum(
    1.0 / use_norm,
    constant_op.constant(1.0, dtype=use_norm.dtype) / clip_norm)

方法的返回值为裁剪后的梯度列表list_clipped和global_norm

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate)
gradients, v = zip(*optimizer.compute_gradients(loss))
gradients, _ = tf.clip_by_global_norm(gradients, grad_clip)
updates = optimizer.apply_gradients(zip(gradients, v),global_step=global_step)

方法三tf.clip_by_average_norm

def clip_by_average_norm(t, clip_norm, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式如下,

Tensorflow之梯度裁剪的实现示例

其中,avg_norm=l2norm_avg(t)

方法四:tf.clip_by_norm

def clip_by_norm(t, clip_norm, axes=None, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式为

Tensorflow之梯度裁剪的实现示例

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
grads = optimizer.compute_gradients(cost)
for i, (g, v) in enumerate(grads):
  if g is not None:
    grads[i] = (tf.clip_by_norm(g, 5), v) # clip gradients
train_op = optimizer.apply_gradients(grads)

注意到,clip_by_value、clib_by-avg_norm和clip_by_norm都是针对于单个张量的,而clip_by_global_norm可用于多个张量组成的列表。

到此这篇关于Tensorflow之梯度裁剪的实现示例的文章就介绍到这了,更多相关Tensorflow 梯度裁剪内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python之PyUnit单元测试实例
Oct 11 Python
python中黄金分割法实现方法
May 06 Python
python访问mysql数据库的实现方法(2则示例)
Jan 06 Python
python实现多线程抓取知乎用户
Dec 12 Python
解决pycharm无法识别本地site-packages的问题
Oct 13 Python
详解python3安装pillow后报错没有pillow模块以及没有PIL模块问题解决
Apr 17 Python
python的移位操作实现详解
Aug 21 Python
python 实现方阵的对角线遍历示例
Nov 29 Python
tensorflow 模型权重导出实例
Jan 24 Python
在Windows上安装和配置 Jupyter Lab 作为桌面级应用程序教程
Apr 22 Python
PyQt5连接MySQL及QMYSQL driver not loaded错误解决
Apr 29 Python
Pytorch DataLoader shuffle验证方式
Jun 02 Python
Django自定义全局403、404、500错误页面的示例代码
Mar 08 #Python
Django 自定义404 500等错误页面的实现
Mar 08 #Python
Python loguru日志库之高效输出控制台日志和日志记录
Mar 07 #Python
Centos7下源码安装Python3 及shell 脚本自动安装Python3的教程
Mar 07 #Python
Django接收照片储存文件的实例代码
Mar 07 #Python
Python实现对adb命令封装
Mar 06 #Python
对Python中 \r, \n, \r\n的彻底理解
Mar 06 #Python
You might like
php 字符过滤类,用于过滤各类用户输入的数据
2009/05/27 PHP
PHP在引号前面添加反斜杠(PHP去除反斜杠)
2013/09/28 PHP
PHP数据过滤的方法
2013/10/30 PHP
PHP Header失效的原因分析及解决方法
2016/11/16 PHP
Centos 6.5下PHP 5.3安装ffmpeg扩展的步骤详解
2017/03/02 PHP
PHP实现用户登录的案例代码
2018/05/10 PHP
jQuery设置div一直在页面顶部显示的方法
2013/10/24 Javascript
jquery动态调整div大小使其宽度始终为浏览器宽度
2014/06/06 Javascript
用jquery修复在iframe下的页面锚点失效问题
2014/08/22 Javascript
jQuery中;function($,undefined) 前面的分号的用处
2014/12/17 Javascript
深入理解JavaScript系列(47):对象创建模式(上篇)
2015/03/04 Javascript
jQuery中extend函数详解
2015/07/13 Javascript
jQuery.trim() 函数及trim()用法详解
2015/10/26 Javascript
跟我学习javascript的异步脚本加载
2015/11/20 Javascript
javascript精确统计网站访问量实例代码
2015/12/19 Javascript
全面解析Bootstrap中nav、collapse的使用方法
2016/05/22 Javascript
基于angular中的重要指令详解($eval,$parse和$compile)
2016/10/21 Javascript
Vue运用transition实现过渡动画
2019/05/06 Javascript
Python中方法链的使用方法
2016/02/23 Python
放弃 Python 转向 Go语言有人给出了 9 大理由
2017/10/20 Python
python爬取指定微信公众号文章
2018/12/20 Python
Python pip 安装与使用(安装、更新、删除)
2019/10/06 Python
python图形用户接口实例详解
2019/12/16 Python
django 实现手动存储文件到model的FileField
2020/03/30 Python
CSS3混合模式mix-blend-mode/background-blend-mode简介
2018/03/15 HTML / CSS
解释一下ArrayList Vector和LinkedList的实现和区别
2013/04/26 面试题
职称自我鉴定
2013/10/15 职场文书
车间主管岗位职责
2013/11/14 职场文书
岗位明星事迹材料
2014/05/18 职场文书
保护水资源的标语
2014/06/17 职场文书
2015年打非治违工作总结
2015/04/02 职场文书
转学证明范本
2015/06/19 职场文书
暂住证明怎么写
2015/06/19 职场文书
导游词之介休绵山
2019/12/31 职场文书
Java并发编程必备之Future机制
2021/06/30 Java/Android
python机器学习创建基于规则聊天机器人过程示例详解
2021/11/02 Python