TensorFlow的权值更新方法


Posted in Python onJune 14, 2018

一. MovingAverage权值滑动平均更新

1.1 示例代码:

def create_target_q_network(self,state_dim,action_dim,net):
  state_input = tf.placeholder("float",[None,state_dim])
  action_input = tf.placeholder("float",[None,action_dim])

  ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
  target_update = ema.apply(net)
  target_net = [ema.average(x) for x in net]

  layer1 = tf.nn.relu(tf.matmul(state_input,target_net[0]) + target_net[1])
  layer2 = tf.nn.relu(tf.matmul(layer1,target_net[2]) + tf.matmul(action_input,target_net[3]) + target_net[4])
  q_value_output = tf.identity(tf.matmul(layer2,target_net[5]) + target_net[6])

  return state_input,action_input,q_value_output,target_update

def update_target(self):
  self.sess.run(self.target_update)

其中,TAU=0.001,net是原始网络(该示例代码来自DDPG算法,经过滑动更新后的target_net是目标网络 )

第一句 tf.train.ExponentialMovingAverage,创建一个权值滑动平均的实例;

第二句 apply创建所训练模型参数的一个复制品(shadow_variable),并对这个复制品增加一个保留权值滑动平均的op,函数average()或average_name()可以用来获取最终这个复制品(平滑后)的值的。

更新公式为:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

在上述代码段中,target_net是shadow_variable,net是variable

1.2 tf.train.ExponentialMovingAverage.apply(var_list=None)

var_list必须是Variable或Tensor形式的列表。这个方法对var_list中所有元素创建一个复制,当其是Variable类型时,shadow_variable被初始化为variable的初值,当其是Tensor类型时,初始化为0,无偏。

函数返回一个进行权值平滑的op,因此更新目标网络时单独run这个函数就行。

1.3 tf.train.ExponentialMovingAverage.average(var)

用于获取var的滑动平均结果。

二. tf.train.Optimizer更新网络权值

2.1 tf.train.Optimizer

tf.train.Optimizer允许网络通过minimize()损失函数自动进行权值更新,此时tf.train.Optimizer.minimize()做了两件事:计算梯度,并把梯度自动更新到权值上。

此外,tensorflow也允许用户自己计算梯度,并做处理后应用给权值进行更新,此时分为以下三个步骤:

1.利用tf.train.Optimizer.compute_gradients计算梯度

2.对梯度进行自定义处理

3.利用tf.train.Optimizer.apply_gradients更新权值

tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None)

返回一个(梯度,权值)的列表对。

tf.train.Optimizer.apply_gradients(grads_and_vars, global_step=None, name=None)

返回一个更新权值的op,因此可以用它的返回值ret进行sess.run(ret)

2.2 其它

此外,tensorflow还提供了其它计算梯度的方法:

• tf.gradients(ys, xs, grad_ys=None, name='gradients', colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None)

该函数计算ys在xs方向上的梯度,需要注意与train.compute_gradients所不同的地方是,该函数返回一组dydx dydx的列表,而不是梯度-权值对。

其中,gate_gradients是在ys方向上的初始梯度,个人理解可以看做是偏微分链式求导中所需要的。

• tf.stop_gradient(input, name=None)

该函数告知整个graph图中,对input不进行梯度计算,将其伪装成一个constant常量。比如,可以用在类似于DQN算法中的目标函数:

cost=|r+Q next −Q current | cost=|r+Qnext−Qcurrent|

可以事先声明

y=tf.stop_gradient(r+Q next r+Qnext)

以上这篇TensorFlow的权值更新方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python2.x中文乱码问题解决方法
Jun 02 Python
Python二叉搜索树与双向链表转换实现方法
Apr 29 Python
Python实现的概率分布运算操作示例
Aug 14 Python
python2.7到3.x迁移指南
Feb 01 Python
解决Tensorflow使用pip安装后没有model目录的问题
Jun 13 Python
Python Dataframe 指定多列去重、求差集的方法
Jul 10 Python
python 画三维图像 曲面图和散点图的示例
Dec 29 Python
python实现烟花小程序
Jan 30 Python
Python音频操作工具PyAudio上手教程详解
Jun 26 Python
TensorFlow实现打印每一层的输出
Jan 21 Python
python 实现Harris角点检测算法
Dec 11 Python
python爬虫scrapy框架之增量式爬虫的示例代码
Feb 26 Python
python字符串常用方法
Jun 14 #Python
tensorflow 输出权重到csv或txt的实例
Jun 14 #Python
修复 Django migration 时遇到的问题解决
Jun 14 #Python
tensorflow 获取模型所有参数总和数量的方法
Jun 14 #Python
tensorflow 获取变量&打印权值的实例讲解
Jun 14 #Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
You might like
PHP生成网站桌面快捷方式代码分享
2014/10/11 PHP
JavaScript的面向对象(二)
2006/11/09 Javascript
Javascript中eval函数的使用方法与示例
2007/04/09 Javascript
一些有用的JavaScript和jQuery的片段分享
2011/08/23 Javascript
jquery方法+js一般方法+js面向对象方法实现拖拽效果
2012/08/30 Javascript
node.js中的dns.getServers方法使用说明
2014/12/08 Javascript
jquery实现的省市区三级联动
2015/04/02 Javascript
js获取浏览器和屏幕的各种宽度高度
2017/02/22 Javascript
详谈javascript精度问题与调整
2017/07/08 Javascript
jQuery实现手势解锁密码特效
2017/08/14 jQuery
bootstrap tooltips在 angularJS中的使用方法
2019/04/10 Javascript
解决cordova+vue 项目打包成APK应用遇到的问题
2019/05/10 Javascript
在JavaScript中实现链式调用的实现
2019/12/24 Javascript
Vuex的各个模块封装的实现
2020/06/05 Javascript
Python实现过滤单个Android程序日志脚本分享
2015/01/16 Python
深入理解python中的select模块
2017/04/23 Python
使用python Telnet远程登录执行程序的方法
2019/01/26 Python
python调用c++传递数组的实例
2019/02/13 Python
python离线安装外部依赖包的实现
2020/02/13 Python
python3.7中安装paddleocr及paddlepaddle包的多种方法
2020/11/27 Python
国外平面设计素材网站:The Hungry JPEG
2017/03/28 全球购物
BIBLOO波兰:捷克的一家在线服装店
2018/03/09 全球购物
捷克鲜花配送:Florea.cz
2018/10/29 全球购物
写一个用矩形法求定积分的通用函数
2012/11/08 面试题
Linux如何命名文件--使用文件名时应注意
2012/01/22 面试题
艺术爱好者的自我评价分享
2013/10/08 职场文书
普通院校学生的自荐信
2013/11/27 职场文书
应用化学专业职业生涯规划书
2013/12/31 职场文书
写给女生的道歉信
2014/01/14 职场文书
市场部业务员岗位职责
2014/04/02 职场文书
无刑事犯罪记录证明范本
2014/09/29 职场文书
专题组织生活会思想汇报
2014/10/01 职场文书
中秋节慰问信
2015/02/15 职场文书
小学教师教学反思
2016/02/24 职场文书
Python如何配置环境变量详解
2021/05/18 Python
集英社今正式宣布 成立游戏公司“集英社Games”
2022/03/31 其他游戏