TensorFlow的权值更新方法


Posted in Python onJune 14, 2018

一. MovingAverage权值滑动平均更新

1.1 示例代码:

def create_target_q_network(self,state_dim,action_dim,net):
  state_input = tf.placeholder("float",[None,state_dim])
  action_input = tf.placeholder("float",[None,action_dim])

  ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
  target_update = ema.apply(net)
  target_net = [ema.average(x) for x in net]

  layer1 = tf.nn.relu(tf.matmul(state_input,target_net[0]) + target_net[1])
  layer2 = tf.nn.relu(tf.matmul(layer1,target_net[2]) + tf.matmul(action_input,target_net[3]) + target_net[4])
  q_value_output = tf.identity(tf.matmul(layer2,target_net[5]) + target_net[6])

  return state_input,action_input,q_value_output,target_update

def update_target(self):
  self.sess.run(self.target_update)

其中,TAU=0.001,net是原始网络(该示例代码来自DDPG算法,经过滑动更新后的target_net是目标网络 )

第一句 tf.train.ExponentialMovingAverage,创建一个权值滑动平均的实例;

第二句 apply创建所训练模型参数的一个复制品(shadow_variable),并对这个复制品增加一个保留权值滑动平均的op,函数average()或average_name()可以用来获取最终这个复制品(平滑后)的值的。

更新公式为:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

在上述代码段中,target_net是shadow_variable,net是variable

1.2 tf.train.ExponentialMovingAverage.apply(var_list=None)

var_list必须是Variable或Tensor形式的列表。这个方法对var_list中所有元素创建一个复制,当其是Variable类型时,shadow_variable被初始化为variable的初值,当其是Tensor类型时,初始化为0,无偏。

函数返回一个进行权值平滑的op,因此更新目标网络时单独run这个函数就行。

1.3 tf.train.ExponentialMovingAverage.average(var)

用于获取var的滑动平均结果。

二. tf.train.Optimizer更新网络权值

2.1 tf.train.Optimizer

tf.train.Optimizer允许网络通过minimize()损失函数自动进行权值更新,此时tf.train.Optimizer.minimize()做了两件事:计算梯度,并把梯度自动更新到权值上。

此外,tensorflow也允许用户自己计算梯度,并做处理后应用给权值进行更新,此时分为以下三个步骤:

1.利用tf.train.Optimizer.compute_gradients计算梯度

2.对梯度进行自定义处理

3.利用tf.train.Optimizer.apply_gradients更新权值

tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None)

返回一个(梯度,权值)的列表对。

tf.train.Optimizer.apply_gradients(grads_and_vars, global_step=None, name=None)

返回一个更新权值的op,因此可以用它的返回值ret进行sess.run(ret)

2.2 其它

此外,tensorflow还提供了其它计算梯度的方法:

• tf.gradients(ys, xs, grad_ys=None, name='gradients', colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None)

该函数计算ys在xs方向上的梯度,需要注意与train.compute_gradients所不同的地方是,该函数返回一组dydx dydx的列表,而不是梯度-权值对。

其中,gate_gradients是在ys方向上的初始梯度,个人理解可以看做是偏微分链式求导中所需要的。

• tf.stop_gradient(input, name=None)

该函数告知整个graph图中,对input不进行梯度计算,将其伪装成一个constant常量。比如,可以用在类似于DQN算法中的目标函数:

cost=|r+Q next −Q current | cost=|r+Qnext−Qcurrent|

可以事先声明

y=tf.stop_gradient(r+Q next r+Qnext)

以上这篇TensorFlow的权值更新方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现dnspod自动更新dns解析的方法
Feb 14 Python
在Windows8上的搭建Python和Django环境
Jul 03 Python
零基础写python爬虫之爬虫框架Scrapy安装配置
Nov 06 Python
Python下载懒人图库JavaScript特效
May 28 Python
浅谈Python中的可迭代对象、迭代器、For循环工作机制、生成器
Mar 11 Python
Python常见的pandas用法demo示例
Mar 16 Python
Python实现微信好友的数据分析
Dec 16 Python
基于python实现微信好友数据分析(简单)
Feb 16 Python
PyInstaller的安装和使用的详细步骤
Jun 02 Python
python使用建议技巧分享(三)
Aug 18 Python
Python实现像awk一样分割字符串
Sep 15 Python
python3中TQDM库安装及使用详解
Nov 18 Python
python字符串常用方法
Jun 14 #Python
tensorflow 输出权重到csv或txt的实例
Jun 14 #Python
修复 Django migration 时遇到的问题解决
Jun 14 #Python
tensorflow 获取模型所有参数总和数量的方法
Jun 14 #Python
tensorflow 获取变量&打印权值的实例讲解
Jun 14 #Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
You might like
php 图片上传类代码
2009/07/17 PHP
php 过滤器实现代码
2010/08/09 PHP
PHP中strncmp()函数比较两个字符串前2个字符是否相等的方法
2016/01/07 PHP
CI框架扩展系统核心类的方法分析
2016/05/23 PHP
PHP url的pathinfo模式加载不同控制器的简单实现
2016/08/12 PHP
php实现将数据做成json的格式给前端使用
2018/08/21 PHP
PHP压缩图片功能的介绍
2019/03/21 PHP
js判断变量是否空值的代码
2008/10/26 Javascript
jquery 获取自定义属性(attr和prop)的实现代码
2012/06/27 Javascript
js添加select下默认的option的value和text的方法
2014/10/19 Javascript
node.js中的定时器nextTick()和setImmediate()区别分析
2014/11/26 Javascript
深入解析JavaScript中函数的Currying柯里化
2016/03/19 Javascript
关于JS中setTimeout()无法调用带参函数问题的解决方法
2016/06/21 Javascript
JS控制HTML元素的显示和隐藏的两种方法
2016/09/27 Javascript
JS重载实现方法分析
2016/12/16 Javascript
解决vue 更改计算属性后select选中值不更改的问题
2018/03/02 Javascript
Vue props用法详解(小结)
2018/07/03 Javascript
从理论角度讨论JavaScript闭包
2019/04/03 Javascript
微信小程序开发打开另一个小程序的实现方法
2020/05/17 Javascript
解决Echarts 显示隐藏后宽度高度变小的问题
2020/07/19 Javascript
python机器学习之决策树分类详解
2017/12/20 Python
在PyCharm下打包*.py程序成.exe的方法
2018/11/29 Python
Python如何实现自带HTTP文件传输服务
2020/07/08 Python
python raise的基本使用
2020/09/10 Python
通过代码实例了解Python sys模块
2020/09/14 Python
HTML5实现文件断点续传的方法
2017/01/04 HTML / CSS
美国最大的香水出口:FragranceX.com
2017/11/04 全球购物
EntityManager都有哪些方法
2013/11/01 面试题
财务管理专业自荐信范文
2013/12/24 职场文书
军训心得体会
2013/12/31 职场文书
初中班级口号
2014/06/09 职场文书
出国签证在职证明
2014/09/20 职场文书
综合实践活动报告
2015/02/05 职场文书
对外汉语教师推荐信
2015/03/27 职场文书
Golang 正则匹配效率详解
2021/04/25 Golang
Rust中的Struct使用示例详解
2022/08/14 Javascript