TensorFlow的权值更新方法


Posted in Python onJune 14, 2018

一. MovingAverage权值滑动平均更新

1.1 示例代码:

def create_target_q_network(self,state_dim,action_dim,net):
  state_input = tf.placeholder("float",[None,state_dim])
  action_input = tf.placeholder("float",[None,action_dim])

  ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
  target_update = ema.apply(net)
  target_net = [ema.average(x) for x in net]

  layer1 = tf.nn.relu(tf.matmul(state_input,target_net[0]) + target_net[1])
  layer2 = tf.nn.relu(tf.matmul(layer1,target_net[2]) + tf.matmul(action_input,target_net[3]) + target_net[4])
  q_value_output = tf.identity(tf.matmul(layer2,target_net[5]) + target_net[6])

  return state_input,action_input,q_value_output,target_update

def update_target(self):
  self.sess.run(self.target_update)

其中,TAU=0.001,net是原始网络(该示例代码来自DDPG算法,经过滑动更新后的target_net是目标网络 )

第一句 tf.train.ExponentialMovingAverage,创建一个权值滑动平均的实例;

第二句 apply创建所训练模型参数的一个复制品(shadow_variable),并对这个复制品增加一个保留权值滑动平均的op,函数average()或average_name()可以用来获取最终这个复制品(平滑后)的值的。

更新公式为:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

在上述代码段中,target_net是shadow_variable,net是variable

1.2 tf.train.ExponentialMovingAverage.apply(var_list=None)

var_list必须是Variable或Tensor形式的列表。这个方法对var_list中所有元素创建一个复制,当其是Variable类型时,shadow_variable被初始化为variable的初值,当其是Tensor类型时,初始化为0,无偏。

函数返回一个进行权值平滑的op,因此更新目标网络时单独run这个函数就行。

1.3 tf.train.ExponentialMovingAverage.average(var)

用于获取var的滑动平均结果。

二. tf.train.Optimizer更新网络权值

2.1 tf.train.Optimizer

tf.train.Optimizer允许网络通过minimize()损失函数自动进行权值更新,此时tf.train.Optimizer.minimize()做了两件事:计算梯度,并把梯度自动更新到权值上。

此外,tensorflow也允许用户自己计算梯度,并做处理后应用给权值进行更新,此时分为以下三个步骤:

1.利用tf.train.Optimizer.compute_gradients计算梯度

2.对梯度进行自定义处理

3.利用tf.train.Optimizer.apply_gradients更新权值

tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None)

返回一个(梯度,权值)的列表对。

tf.train.Optimizer.apply_gradients(grads_and_vars, global_step=None, name=None)

返回一个更新权值的op,因此可以用它的返回值ret进行sess.run(ret)

2.2 其它

此外,tensorflow还提供了其它计算梯度的方法:

• tf.gradients(ys, xs, grad_ys=None, name='gradients', colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None)

该函数计算ys在xs方向上的梯度,需要注意与train.compute_gradients所不同的地方是,该函数返回一组dydx dydx的列表,而不是梯度-权值对。

其中,gate_gradients是在ys方向上的初始梯度,个人理解可以看做是偏微分链式求导中所需要的。

• tf.stop_gradient(input, name=None)

该函数告知整个graph图中,对input不进行梯度计算,将其伪装成一个constant常量。比如,可以用在类似于DQN算法中的目标函数:

cost=|r+Q next −Q current | cost=|r+Qnext−Qcurrent|

可以事先声明

y=tf.stop_gradient(r+Q next r+Qnext)

以上这篇TensorFlow的权值更新方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python脚本实现分析dns日志并对受访域名排行
Sep 18 Python
django模型中的字段和model名显示为中文小技巧分享
Nov 18 Python
python使用pil生成缩略图的方法
Mar 26 Python
django用户注册、登录、注销和用户扩展的示例
Mar 19 Python
使用Python横向合并excel文件的实例
Dec 11 Python
pycharm中显示CSS提示的知识点总结
Jul 29 Python
Python如何获取Win7,Win10系统缩放大小
Jan 10 Python
Keras 利用sklearn的ROC-AUC建立评价函数详解
Jun 15 Python
最新PyCharm从安装到PyCharm永久激活再到PyCharm官方中文汉化详细教程
Nov 17 Python
selenium判断元素是否存在的两种方法小结
Dec 07 Python
浅析python连接数据库的重要事项
Feb 22 Python
python Django框架快速入门教程(后台管理)
Jul 21 Python
python字符串常用方法
Jun 14 #Python
tensorflow 输出权重到csv或txt的实例
Jun 14 #Python
修复 Django migration 时遇到的问题解决
Jun 14 #Python
tensorflow 获取模型所有参数总和数量的方法
Jun 14 #Python
tensorflow 获取变量&打印权值的实例讲解
Jun 14 #Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
You might like
PHP隐形一句话后门,和ThinkPHP框架加密码程序(base64_decode)
2011/11/02 PHP
PHP使用DES进行加密与解密的方法详解
2013/06/06 PHP
淘宝ip地址查询类分享(利用淘宝ip库)
2014/01/07 PHP
php小技巧之过滤ascii控制字符
2014/05/14 PHP
php实现mysql事务处理的方法
2014/12/25 PHP
[原创]php常用字符串输出方法分析(echo,print,printf及sprintf)
2016/07/09 PHP
PHP 获取指定地区的天气实例代码
2017/02/08 PHP
php从身份证获取性别和出生年月
2017/02/09 PHP
php实现的简单多进程服务器类完整示例
2020/02/01 PHP
Asp.net下利用Jquery Ajax实现用户注册检测(验证用户名是否存)
2010/09/12 Javascript
js open() 与showModalDialog()方法使用介绍
2013/09/10 Javascript
Jquery EasyUI的添加,修改,删除,查询等基本操作介绍
2013/10/11 Javascript
基于jQuery实现最基本的淡入淡出效果实例
2015/02/02 Javascript
Angular 4 指令快速入门教程
2017/06/07 Javascript
详解React Native顶|底部导航使用小技巧
2017/09/14 Javascript
脚手架vue-cli工程webpack的基本用法详解
2018/09/29 Javascript
vuex state中的数组变化监听实例
2019/11/06 Javascript
JavaScript代码模拟鼠标自动点击事件示例
2020/08/07 Javascript
vue组件中节流函数的失效的原因和解决方法
2020/12/02 Vue.js
使用python opencv对目录下图片进行去重的方法
2019/01/12 Python
使用python实现unix2dos和dos2unix命令的例子
2019/08/13 Python
使用Django搭建一个基金模拟交易系统教程
2019/11/18 Python
PyCharm下载和安装详细步骤
2019/12/17 Python
Python 音频生成器的实现示例
2019/12/24 Python
python 瀑布线指标编写实例
2020/06/03 Python
解决pip安装的第三方包在PyCharm无法导入的问题
2020/10/15 Python
盖尔斯工厂店:GUESS Factory
2020/01/21 全球购物
点菜员岗位职责范本
2014/02/14 职场文书
资金主管岗位职责范本
2014/03/04 职场文书
小学生爱国演讲稿
2014/04/25 职场文书
基层党建工作汇报材料
2014/08/15 职场文书
2015年圣诞节活动总结
2015/03/24 职场文书
寒假生活随笔
2015/08/15 职场文书
小学记事作文之200字
2019/08/06 职场文书
Golang原生rpc(rpc服务端源码解读)
2022/04/07 Golang
Python爬虫 简单介绍一下Xpath及使用
2022/04/26 Python