Tensorflow之梯度裁剪的实现示例


Posted in Python onMarch 08, 2020

tensorflow中的梯度计算和更新

为了解决深度学习中常见的梯度消失(gradient explosion)和梯度爆炸(gradients vanishing)问题,tensorflow中所有的优化器tf.train.xxxOptimizer都有两个方法:

  1. compute_gradients
  2. apply_gradients

compute_gradients

对于compute_gradients方法,计算var_list中参数的梯度,使得loss变小。默认情况下,var_list为GraphKeys.TRAINABLE_VARIABLES中的所有参数。

compute_gradients方法返回由多个(gradients, variable)二元组组成的列表。

compute_gradients(
  loss,
  var_list=None,
  gate_gradients=GATE_OP,
  aggregation_method=None,
  colocate_gradients_with_ops=False,
  grad_loss=None
)

apply_gradients

对于apply_gradients方法,根据compute_gradients的返回结果对参数进行更新

apply_gradients(
  grads_and_vars,
  global_step=None,
  name=None
)

梯度裁剪(Gradient Clipping)

tensorflow中裁剪梯度的几种方式

方法一tf.clip_by_value

def clip_by_value(t, clip_value_min, clip_value_max,
         name=None):

其中,t为一个张量,clip_by_value返回一个与t的type相同、shape相同的张量,但是新tensor中的值被裁剪到了clip_value_min和clip_value_max之间。

方法二:tf.clip_by_global_norm

def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):

其中,t_list为A tuple or list of mixed Tensors, IndexedSlices, or None。clip_norm为clipping ratio,use_norm指定global_norm,如果use_norm为None,则按global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))计算global_norm。

最终,梯度的裁剪方式为

Tensorflow之梯度裁剪的实现示例

可知,如果clip_norm > global_norm, 则不对梯度进行裁剪,否则对梯度进行缩放。

scale = clip_norm * math_ops.minimum(
    1.0 / use_norm,
    constant_op.constant(1.0, dtype=use_norm.dtype) / clip_norm)

方法的返回值为裁剪后的梯度列表list_clipped和global_norm

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate)
gradients, v = zip(*optimizer.compute_gradients(loss))
gradients, _ = tf.clip_by_global_norm(gradients, grad_clip)
updates = optimizer.apply_gradients(zip(gradients, v),global_step=global_step)

方法三tf.clip_by_average_norm

def clip_by_average_norm(t, clip_norm, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式如下,

Tensorflow之梯度裁剪的实现示例

其中,avg_norm=l2norm_avg(t)

方法四:tf.clip_by_norm

def clip_by_norm(t, clip_norm, axes=None, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式为

Tensorflow之梯度裁剪的实现示例

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
grads = optimizer.compute_gradients(cost)
for i, (g, v) in enumerate(grads):
  if g is not None:
    grads[i] = (tf.clip_by_norm(g, 5), v) # clip gradients
train_op = optimizer.apply_gradients(grads)

注意到,clip_by_value、clib_by-avg_norm和clip_by_norm都是针对于单个张量的,而clip_by_global_norm可用于多个张量组成的列表。

到此这篇关于Tensorflow之梯度裁剪的实现示例的文章就介绍到这了,更多相关Tensorflow 梯度裁剪内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python中文乱码的解决方法
Nov 04 Python
python实现爬虫下载漫画示例
Feb 16 Python
解决pyqt中ui编译成窗体.py中文乱码的问题
Dec 23 Python
Python下载网络文本数据到本地内存的四种实现方法示例
Feb 05 Python
PyQt5每天必学之创建窗口居中效果
Apr 19 Python
Python基于百度AI的文字识别的示例
Apr 21 Python
详解Django之admin组件的使用和源码剖析
May 04 Python
python实现画一颗树和一片森林
Jun 25 Python
数据清洗--DataFrame中的空值处理方法
Jul 03 Python
Python 串口读写的实现方法
Jun 12 Python
python与pycharm有何区别
Jul 01 Python
Python读取xlsx数据生成图标代码实例
Aug 12 Python
Django自定义全局403、404、500错误页面的示例代码
Mar 08 #Python
Django 自定义404 500等错误页面的实现
Mar 08 #Python
Python loguru日志库之高效输出控制台日志和日志记录
Mar 07 #Python
Centos7下源码安装Python3 及shell 脚本自动安装Python3的教程
Mar 07 #Python
Django接收照片储存文件的实例代码
Mar 07 #Python
Python实现对adb命令封装
Mar 06 #Python
对Python中 \r, \n, \r\n的彻底理解
Mar 06 #Python
You might like
php 字符串函数收集
2010/03/29 PHP
yii使用activeFileField控件实现上传文件与图片的方法
2015/12/28 PHP
PHP实现负载均衡下的session共用功能
2018/04/17 PHP
对YUI扩展的Gird组件 Part-1
2007/03/10 Javascript
JavaScript 组件之旅(三):用 Ant 构建组件
2009/10/28 Javascript
JavaScript 常见对象类创建代码与优缺点分析
2009/12/07 Javascript
原生js仿jq判断当前浏览器是否为ie,精确到ie6~8
2014/08/30 Javascript
JS实现在线统计一个页面内鼠标点击次数的方法
2015/02/28 Javascript
Jquery attr()方法 属性赋值和属性获取详解
2016/04/15 Javascript
详解Vue项目中实现锚点定位
2019/04/24 Javascript
vue transition 在子组件中失效的解决
2019/11/12 Javascript
JS模拟实现京东快递单号查询
2020/11/30 Javascript
vant时间控件使用方法详解
2020/12/24 Javascript
SpringBoot+Vue 前后端合并部署的配置方法
2020/12/30 Vue.js
解决vue项目本地启动时无法携带cookie的问题
2021/02/06 Vue.js
[03:06]V社市场总监Dota2项目负责人Erik专访:希望更多中国玩家加入DOTA2
2014/07/11 DOTA
深入剖析Python的爬虫框架Scrapy的结构与运作流程
2016/01/20 Python
基于python3实现socket文件传输和校验
2018/07/28 Python
Pycharm取消py脚本中SQL识别的方法
2018/11/29 Python
解决python给列表里添加字典时被最后一个覆盖的问题
2019/01/21 Python
Flask框架模板渲染操作简单示例
2019/07/31 Python
python数组循环处理方法
2019/08/26 Python
OpenCV python sklearn随机超参数搜索的实现
2020/01/17 Python
解决Jupyter notebook中.py与.ipynb文件的import问题
2020/04/21 Python
Python3合并两个有序数组代码实例
2020/08/11 Python
FILA斐乐中国官方商城:意大利运动品牌
2017/01/25 全球购物
Java语言程序设计测试题选择题部分
2014/04/03 面试题
高级销售员求职信
2013/10/25 职场文书
大门门卫岗位职责
2013/11/30 职场文书
创新型城市实施方案
2014/03/06 职场文书
庆元旦文艺演出主持词
2014/03/27 职场文书
史学专业毕业生求职信
2014/05/09 职场文书
公共场所禁烟标语
2014/06/25 职场文书
教师个人考察材料
2014/12/16 职场文书
清明节主题班会
2015/08/14 职场文书
团支部组织委员竞选稿
2015/11/21 职场文书