TensorFlow的权值更新方法


Posted in Python onJune 14, 2018

一. MovingAverage权值滑动平均更新

1.1 示例代码:

def create_target_q_network(self,state_dim,action_dim,net):
  state_input = tf.placeholder("float",[None,state_dim])
  action_input = tf.placeholder("float",[None,action_dim])

  ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
  target_update = ema.apply(net)
  target_net = [ema.average(x) for x in net]

  layer1 = tf.nn.relu(tf.matmul(state_input,target_net[0]) + target_net[1])
  layer2 = tf.nn.relu(tf.matmul(layer1,target_net[2]) + tf.matmul(action_input,target_net[3]) + target_net[4])
  q_value_output = tf.identity(tf.matmul(layer2,target_net[5]) + target_net[6])

  return state_input,action_input,q_value_output,target_update

def update_target(self):
  self.sess.run(self.target_update)

其中,TAU=0.001,net是原始网络(该示例代码来自DDPG算法,经过滑动更新后的target_net是目标网络 )

第一句 tf.train.ExponentialMovingAverage,创建一个权值滑动平均的实例;

第二句 apply创建所训练模型参数的一个复制品(shadow_variable),并对这个复制品增加一个保留权值滑动平均的op,函数average()或average_name()可以用来获取最终这个复制品(平滑后)的值的。

更新公式为:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

在上述代码段中,target_net是shadow_variable,net是variable

1.2 tf.train.ExponentialMovingAverage.apply(var_list=None)

var_list必须是Variable或Tensor形式的列表。这个方法对var_list中所有元素创建一个复制,当其是Variable类型时,shadow_variable被初始化为variable的初值,当其是Tensor类型时,初始化为0,无偏。

函数返回一个进行权值平滑的op,因此更新目标网络时单独run这个函数就行。

1.3 tf.train.ExponentialMovingAverage.average(var)

用于获取var的滑动平均结果。

二. tf.train.Optimizer更新网络权值

2.1 tf.train.Optimizer

tf.train.Optimizer允许网络通过minimize()损失函数自动进行权值更新,此时tf.train.Optimizer.minimize()做了两件事:计算梯度,并把梯度自动更新到权值上。

此外,tensorflow也允许用户自己计算梯度,并做处理后应用给权值进行更新,此时分为以下三个步骤:

1.利用tf.train.Optimizer.compute_gradients计算梯度

2.对梯度进行自定义处理

3.利用tf.train.Optimizer.apply_gradients更新权值

tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None)

返回一个(梯度,权值)的列表对。

tf.train.Optimizer.apply_gradients(grads_and_vars, global_step=None, name=None)

返回一个更新权值的op,因此可以用它的返回值ret进行sess.run(ret)

2.2 其它

此外,tensorflow还提供了其它计算梯度的方法:

• tf.gradients(ys, xs, grad_ys=None, name='gradients', colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None)

该函数计算ys在xs方向上的梯度,需要注意与train.compute_gradients所不同的地方是,该函数返回一组dydx dydx的列表,而不是梯度-权值对。

其中,gate_gradients是在ys方向上的初始梯度,个人理解可以看做是偏微分链式求导中所需要的。

• tf.stop_gradient(input, name=None)

该函数告知整个graph图中,对input不进行梯度计算,将其伪装成一个constant常量。比如,可以用在类似于DQN算法中的目标函数:

cost=|r+Q next −Q current | cost=|r+Qnext−Qcurrent|

可以事先声明

y=tf.stop_gradient(r+Q next r+Qnext)

以上这篇TensorFlow的权值更新方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用python实现的可以拷贝或剪切一个文件列表中的所有文件
Apr 30 Python
基于scrapy实现的简单蜘蛛采集程序
Apr 17 Python
python读取word文档的方法
May 09 Python
使用Python将数组的元素导出到变量中(unpacking)
Oct 27 Python
每天迁移MySQL历史数据到历史库Python脚本
Apr 13 Python
python读取txt文件中特定位置字符的方法
Dec 24 Python
浅谈python 读excel数值为浮点型的问题
Dec 25 Python
python对矩阵进行转置的2种处理方法
Jul 17 Python
Python学习笔记之Break和Continue用法分析
Aug 14 Python
Python threading.local代码实例及原理解析
Mar 16 Python
浅谈cv2.imread()和keras.preprocessing中的image.load_img()区别
Jun 12 Python
教你怎么用Python生成九宫格照片
May 20 Python
python字符串常用方法
Jun 14 #Python
tensorflow 输出权重到csv或txt的实例
Jun 14 #Python
修复 Django migration 时遇到的问题解决
Jun 14 #Python
tensorflow 获取模型所有参数总和数量的方法
Jun 14 #Python
tensorflow 获取变量&打印权值的实例讲解
Jun 14 #Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
You might like
如何从一个php文件向另一个地址post数据,不用表单和隐藏的变量的
2007/03/06 PHP
php获取错误信息的方法
2015/07/17 PHP
PHP面向对象程序设计之多态性的应用示例
2018/12/19 PHP
laravel开发环境homestead搭建过程详解
2020/07/03 PHP
javascript对数组的常用操作代码 数组方法总汇
2011/01/27 Javascript
js关闭父窗口时关闭子窗口
2013/04/01 Javascript
在js文件中如何获取basePath处理js路径问题
2013/07/10 Javascript
jQuery中removeClass()方法用法实例
2015/01/05 Javascript
前端框架Vue.js构建大型应用浅析
2016/09/12 Javascript
Sequelize中用group by进行分组聚合查询
2016/12/12 Javascript
Node.js中文件操作模块File System的详细介绍
2017/01/05 Javascript
jQuery实现单击按钮遮罩弹出对话框效果(2)
2017/02/20 Javascript
详解Vue 动态添加模板的几种方法
2017/04/25 Javascript
Vue Ajax跨域请求实例详解
2017/06/20 Javascript
react开发教程之React 组件之间的通信方式
2017/08/12 Javascript
Iphone手机、安卓手机浏览器控制默认缩放大小的方法总结(附代码)
2017/08/18 Javascript
10分钟彻底搞懂Http的强制缓存和协商缓存(小结)
2018/08/30 Javascript
JavaScript中的类型检查
2020/02/03 Javascript
koa-passport实现本地验证的方法示例
2020/02/20 Javascript
js获取图片的base64编码并压缩
2020/12/05 Javascript
Python并发编程协程(Coroutine)之Gevent详解
2017/12/27 Python
python实现判断一个字符串是否是合法IP地址的示例
2018/06/04 Python
Python实现删除某列中含有空值的行的示例代码
2020/07/20 Python
Python 发送邮件方法总结
2020/08/10 Python
生物技术专业研究生自荐信
2013/09/22 职场文书
生产部岗位职责范文
2014/02/07 职场文书
旅游市场营销方案
2014/03/09 职场文书
精彩的广告词
2014/03/19 职场文书
联谊活动总结
2014/08/28 职场文书
第二批党的群众路线教育实践活动个人整改方案
2014/10/31 职场文书
打架检讨书
2015/01/27 职场文书
2015年世界艾滋病日活动总结
2015/03/24 职场文书
地雷战观后感
2015/06/09 职场文书
使用CSS设置滚动条样式
2022/01/18 HTML / CSS
游戏《我的世界》澄清Xbox版暂无计划加入光追
2022/04/03 其他游戏
python和Appium的移动端多设备自动化测试框架
2022/04/26 Python