TensorFlow的权值更新方法


Posted in Python onJune 14, 2018

一. MovingAverage权值滑动平均更新

1.1 示例代码:

def create_target_q_network(self,state_dim,action_dim,net):
  state_input = tf.placeholder("float",[None,state_dim])
  action_input = tf.placeholder("float",[None,action_dim])

  ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
  target_update = ema.apply(net)
  target_net = [ema.average(x) for x in net]

  layer1 = tf.nn.relu(tf.matmul(state_input,target_net[0]) + target_net[1])
  layer2 = tf.nn.relu(tf.matmul(layer1,target_net[2]) + tf.matmul(action_input,target_net[3]) + target_net[4])
  q_value_output = tf.identity(tf.matmul(layer2,target_net[5]) + target_net[6])

  return state_input,action_input,q_value_output,target_update

def update_target(self):
  self.sess.run(self.target_update)

其中,TAU=0.001,net是原始网络(该示例代码来自DDPG算法,经过滑动更新后的target_net是目标网络 )

第一句 tf.train.ExponentialMovingAverage,创建一个权值滑动平均的实例;

第二句 apply创建所训练模型参数的一个复制品(shadow_variable),并对这个复制品增加一个保留权值滑动平均的op,函数average()或average_name()可以用来获取最终这个复制品(平滑后)的值的。

更新公式为:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

在上述代码段中,target_net是shadow_variable,net是variable

1.2 tf.train.ExponentialMovingAverage.apply(var_list=None)

var_list必须是Variable或Tensor形式的列表。这个方法对var_list中所有元素创建一个复制,当其是Variable类型时,shadow_variable被初始化为variable的初值,当其是Tensor类型时,初始化为0,无偏。

函数返回一个进行权值平滑的op,因此更新目标网络时单独run这个函数就行。

1.3 tf.train.ExponentialMovingAverage.average(var)

用于获取var的滑动平均结果。

二. tf.train.Optimizer更新网络权值

2.1 tf.train.Optimizer

tf.train.Optimizer允许网络通过minimize()损失函数自动进行权值更新,此时tf.train.Optimizer.minimize()做了两件事:计算梯度,并把梯度自动更新到权值上。

此外,tensorflow也允许用户自己计算梯度,并做处理后应用给权值进行更新,此时分为以下三个步骤:

1.利用tf.train.Optimizer.compute_gradients计算梯度

2.对梯度进行自定义处理

3.利用tf.train.Optimizer.apply_gradients更新权值

tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None)

返回一个(梯度,权值)的列表对。

tf.train.Optimizer.apply_gradients(grads_and_vars, global_step=None, name=None)

返回一个更新权值的op,因此可以用它的返回值ret进行sess.run(ret)

2.2 其它

此外,tensorflow还提供了其它计算梯度的方法:

• tf.gradients(ys, xs, grad_ys=None, name='gradients', colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None)

该函数计算ys在xs方向上的梯度,需要注意与train.compute_gradients所不同的地方是,该函数返回一组dydx dydx的列表,而不是梯度-权值对。

其中,gate_gradients是在ys方向上的初始梯度,个人理解可以看做是偏微分链式求导中所需要的。

• tf.stop_gradient(input, name=None)

该函数告知整个graph图中,对input不进行梯度计算,将其伪装成一个constant常量。比如,可以用在类似于DQN算法中的目标函数:

cost=|r+Q next −Q current | cost=|r+Qnext−Qcurrent|

可以事先声明

y=tf.stop_gradient(r+Q next r+Qnext)

以上这篇TensorFlow的权值更新方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python之wxPython菜单使用详解
Sep 28 Python
python 写入csv乱码问题解决方法
Oct 23 Python
基于使用paramiko执行远程linux主机命令(详解)
Oct 16 Python
python数字图像处理之骨架提取与分水岭算法
Apr 27 Python
用python建立两个Y轴的XY曲线图方法
Jul 08 Python
Django结合ajax进行页面实时更新的例子
Aug 12 Python
python中通过selenium简单操作及元素定位知识点总结
Sep 10 Python
python画微信表情符的实例代码
Oct 09 Python
python将音频进行变速的操作方法
Apr 08 Python
jupyter notebook 使用过程中python莫名崩溃的原因及解决方式
Apr 10 Python
OpenCV图片漫画效果的实现示例
Aug 18 Python
python super()函数的基本使用
Sep 10 Python
python字符串常用方法
Jun 14 #Python
tensorflow 输出权重到csv或txt的实例
Jun 14 #Python
修复 Django migration 时遇到的问题解决
Jun 14 #Python
tensorflow 获取模型所有参数总和数量的方法
Jun 14 #Python
tensorflow 获取变量&打印权值的实例讲解
Jun 14 #Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
You might like
WordPress中登陆后关闭登陆页面及设置用户不可见栏目
2015/12/31 PHP
Laravel 创建可以传递参数 Console服务的例子
2019/10/14 PHP
070823更新的一个[消息提示框]组件 兼容ie7
2007/08/29 Javascript
使用jQuery监听扫码枪输入并禁止手动输入的实现方法(推荐)
2017/03/21 jQuery
浅谈vue自定义全局组件并通过全局方法 Vue.use() 使用该组件
2017/12/07 Javascript
vue scroller返回页面记住滚动位置的实例代码
2018/01/29 Javascript
iview日期控件,双向绑定日期格式的方法
2018/03/15 Javascript
小程序获取周围IBeacon设备的方法
2018/10/31 Javascript
浅析Proxy可以优化vue的数据监听机制问题及实现思路
2018/11/29 Javascript
antd-DatePicker组件获取时间值,及相关设置方式
2020/10/27 Javascript
Vue中inheritAttrs的使用实例详解
2020/12/31 Vue.js
Python读写txt文本文件的操作方法全解析
2016/06/26 Python
python中append实例用法总结
2019/07/30 Python
Django命名URL和反向解析URL实现解析
2019/08/09 Python
Django 路由层URLconf的实现
2019/12/30 Python
python中提高pip install速度
2020/02/14 Python
Django models文件模型变更错误解决
2020/05/11 Python
一文解决django 2.2与mysql兼容性问题
2020/07/15 Python
BeautifulSoup获取指定class样式的div的实现
2020/12/07 Python
ASP.NET Core中的配置详解
2021/02/05 Python
HTML5 Canvas的事件处理介绍
2015/04/24 HTML / CSS
使用phonegap查找联系人的实现方法
2017/03/31 HTML / CSS
彪马土耳其官网:PUMA土耳其
2019/07/14 全球购物
枚举和一组预处理的#define有什么不同
2016/09/21 面试题
C#如何允许一个类被继承但是避免这个类的方法被重载?
2015/02/24 面试题
农行实习自我鉴定
2013/09/22 职场文书
影视动画专业个人的自我评价
2013/12/31 职场文书
毕业生自荐书
2014/02/02 职场文书
2014年保卫科工作总结
2014/12/05 职场文书
办公室管理规章制度
2015/08/04 职场文书
安全生产学习心得体会
2016/01/18 职场文书
Golang之sync.Pool使用详解
2021/05/06 Golang
IDEA使用SpringAssistant插件创建SpringCloud项目
2021/06/23 Java/Android
自从在 IDEA 中用了热部署神器 JRebel 之后,开发效率提升了 10(真棒)
2021/06/26 Java/Android
MySQL into_Mysql中replace与replace into用法案例详解
2021/09/14 MySQL
python数字图像处理之对比度与亮度调整示例
2022/06/28 Python