Tensorflow之梯度裁剪的实现示例


Posted in Python onMarch 08, 2020

tensorflow中的梯度计算和更新

为了解决深度学习中常见的梯度消失(gradient explosion)和梯度爆炸(gradients vanishing)问题,tensorflow中所有的优化器tf.train.xxxOptimizer都有两个方法:

  1. compute_gradients
  2. apply_gradients

compute_gradients

对于compute_gradients方法,计算var_list中参数的梯度,使得loss变小。默认情况下,var_list为GraphKeys.TRAINABLE_VARIABLES中的所有参数。

compute_gradients方法返回由多个(gradients, variable)二元组组成的列表。

compute_gradients(
  loss,
  var_list=None,
  gate_gradients=GATE_OP,
  aggregation_method=None,
  colocate_gradients_with_ops=False,
  grad_loss=None
)

apply_gradients

对于apply_gradients方法,根据compute_gradients的返回结果对参数进行更新

apply_gradients(
  grads_and_vars,
  global_step=None,
  name=None
)

梯度裁剪(Gradient Clipping)

tensorflow中裁剪梯度的几种方式

方法一tf.clip_by_value

def clip_by_value(t, clip_value_min, clip_value_max,
         name=None):

其中,t为一个张量,clip_by_value返回一个与t的type相同、shape相同的张量,但是新tensor中的值被裁剪到了clip_value_min和clip_value_max之间。

方法二:tf.clip_by_global_norm

def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):

其中,t_list为A tuple or list of mixed Tensors, IndexedSlices, or None。clip_norm为clipping ratio,use_norm指定global_norm,如果use_norm为None,则按global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))计算global_norm。

最终,梯度的裁剪方式为

Tensorflow之梯度裁剪的实现示例

可知,如果clip_norm > global_norm, 则不对梯度进行裁剪,否则对梯度进行缩放。

scale = clip_norm * math_ops.minimum(
    1.0 / use_norm,
    constant_op.constant(1.0, dtype=use_norm.dtype) / clip_norm)

方法的返回值为裁剪后的梯度列表list_clipped和global_norm

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate)
gradients, v = zip(*optimizer.compute_gradients(loss))
gradients, _ = tf.clip_by_global_norm(gradients, grad_clip)
updates = optimizer.apply_gradients(zip(gradients, v),global_step=global_step)

方法三tf.clip_by_average_norm

def clip_by_average_norm(t, clip_norm, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式如下,

Tensorflow之梯度裁剪的实现示例

其中,avg_norm=l2norm_avg(t)

方法四:tf.clip_by_norm

def clip_by_norm(t, clip_norm, axes=None, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式为

Tensorflow之梯度裁剪的实现示例

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
grads = optimizer.compute_gradients(cost)
for i, (g, v) in enumerate(grads):
  if g is not None:
    grads[i] = (tf.clip_by_norm(g, 5), v) # clip gradients
train_op = optimizer.apply_gradients(grads)

注意到,clip_by_value、clib_by-avg_norm和clip_by_norm都是针对于单个张量的,而clip_by_global_norm可用于多个张量组成的列表。

到此这篇关于Tensorflow之梯度裁剪的实现示例的文章就介绍到这了,更多相关Tensorflow 梯度裁剪内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python中list列表的一些进阶使用方法介绍
Aug 15 Python
Python时间模块datetime、time、calendar的使用方法
Jan 13 Python
python之Socket网络编程详解
Sep 29 Python
python traceback捕获并打印异常的方法
Aug 31 Python
Django admin.py 在修改/添加表单界面显示额外字段的方法
Aug 22 Python
Python pygame绘制文字制作滚动文字过程解析
Dec 12 Python
对Python中 \r, \n, \r\n的彻底理解
Mar 06 Python
Python小白学习爬虫常用请求报头
Jun 03 Python
python如何调用java类
Jul 05 Python
Python 如何调试程序崩溃错误
Aug 03 Python
基于python图书馆管理系统设计实例详解
Aug 05 Python
python如何快速拼接字符串
Oct 28 Python
Django自定义全局403、404、500错误页面的示例代码
Mar 08 #Python
Django 自定义404 500等错误页面的实现
Mar 08 #Python
Python loguru日志库之高效输出控制台日志和日志记录
Mar 07 #Python
Centos7下源码安装Python3 及shell 脚本自动安装Python3的教程
Mar 07 #Python
Django接收照片储存文件的实例代码
Mar 07 #Python
Python实现对adb命令封装
Mar 06 #Python
对Python中 \r, \n, \r\n的彻底理解
Mar 06 #Python
You might like
php小偷相关截取函数备忘
2010/11/28 PHP
Smarty foreach控制循环次数的一些方法
2015/07/01 PHP
各种快递查询--Api接口
2016/04/26 PHP
PHP 7.0新增加的特性介绍
2017/06/08 PHP
Yii2 队列 shmilyzxt/yii2-queue 简单概述
2017/08/02 PHP
Laravel5.5 手动分页和自定义分页样式的简单实现
2019/10/15 PHP
jquery高效反选具体实现
2013/05/05 Javascript
Ext JS 4实现带week(星期)的日期选择控件(实战二)
2013/08/21 Javascript
利用jQuery实现可以编辑的表格
2014/05/26 Javascript
JS判断网页广告是否被浏览器拦截过滤的代码
2015/04/05 Javascript
jQuery插件实现多级联动菜单效果
2015/12/01 Javascript
JavaScript实现瀑布流布局
2020/06/28 Javascript
jQuery表单对象属性过滤选择器实例详解
2016/09/13 Javascript
Angular2 组件通信的实例代码
2017/06/23 Javascript
详解Vue如何支持JSX语法
2017/11/10 Javascript
nodejs读取本地中文json文件出现乱码解决方法
2018/10/10 NodeJs
layui table动态表头 改变表格头部 重新加载表格的方法
2019/09/21 Javascript
Vue el-autocomplete远程搜索下拉框并实现自动填充功能(推荐)
2019/10/25 Javascript
Vue SPA 首屏优化方案
2021/02/26 Vue.js
Python制作exe文件简单流程
2019/01/24 Python
Pycharm安装第三方库失败解决方案
2020/11/17 Python
利用css3实现的简单的鼠标悬停按钮
2014/11/04 HTML / CSS
CSS3简单实现照片墙
2014/12/12 HTML / CSS
无需JS和jQuery代码实现CSS3鼠标浮动放大图片
2016/11/21 HTML / CSS
Michael Kors加拿大官网:购买设计师手袋、手表、鞋子、服装等
2019/03/16 全球购物
自荐信模版
2013/10/24 职场文书
预备党员表决心书
2014/03/11 职场文书
教师党员岗位承诺书
2014/05/29 职场文书
婚内分居协议书范文
2014/11/26 职场文书
汽车4S店前台接待岗位职责
2015/04/03 职场文书
门球健将观后感
2015/06/16 职场文书
2019年入党思想汇报
2019/03/25 职场文书
导游词之香港-太平山顶
2019/10/18 职场文书
Mysql8.0递归查询的简单用法示例
2021/08/04 MySQL
Ruby GDBM操作简介及数据存储原理
2022/04/19 Ruby
MySQL主从切换的超详细步骤
2022/06/28 MySQL