TensorFlow的权值更新方法


Posted in Python onJune 14, 2018

一. MovingAverage权值滑动平均更新

1.1 示例代码:

def create_target_q_network(self,state_dim,action_dim,net):
  state_input = tf.placeholder("float",[None,state_dim])
  action_input = tf.placeholder("float",[None,action_dim])

  ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
  target_update = ema.apply(net)
  target_net = [ema.average(x) for x in net]

  layer1 = tf.nn.relu(tf.matmul(state_input,target_net[0]) + target_net[1])
  layer2 = tf.nn.relu(tf.matmul(layer1,target_net[2]) + tf.matmul(action_input,target_net[3]) + target_net[4])
  q_value_output = tf.identity(tf.matmul(layer2,target_net[5]) + target_net[6])

  return state_input,action_input,q_value_output,target_update

def update_target(self):
  self.sess.run(self.target_update)

其中,TAU=0.001,net是原始网络(该示例代码来自DDPG算法,经过滑动更新后的target_net是目标网络 )

第一句 tf.train.ExponentialMovingAverage,创建一个权值滑动平均的实例;

第二句 apply创建所训练模型参数的一个复制品(shadow_variable),并对这个复制品增加一个保留权值滑动平均的op,函数average()或average_name()可以用来获取最终这个复制品(平滑后)的值的。

更新公式为:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

在上述代码段中,target_net是shadow_variable,net是variable

1.2 tf.train.ExponentialMovingAverage.apply(var_list=None)

var_list必须是Variable或Tensor形式的列表。这个方法对var_list中所有元素创建一个复制,当其是Variable类型时,shadow_variable被初始化为variable的初值,当其是Tensor类型时,初始化为0,无偏。

函数返回一个进行权值平滑的op,因此更新目标网络时单独run这个函数就行。

1.3 tf.train.ExponentialMovingAverage.average(var)

用于获取var的滑动平均结果。

二. tf.train.Optimizer更新网络权值

2.1 tf.train.Optimizer

tf.train.Optimizer允许网络通过minimize()损失函数自动进行权值更新,此时tf.train.Optimizer.minimize()做了两件事:计算梯度,并把梯度自动更新到权值上。

此外,tensorflow也允许用户自己计算梯度,并做处理后应用给权值进行更新,此时分为以下三个步骤:

1.利用tf.train.Optimizer.compute_gradients计算梯度

2.对梯度进行自定义处理

3.利用tf.train.Optimizer.apply_gradients更新权值

tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None)

返回一个(梯度,权值)的列表对。

tf.train.Optimizer.apply_gradients(grads_and_vars, global_step=None, name=None)

返回一个更新权值的op,因此可以用它的返回值ret进行sess.run(ret)

2.2 其它

此外,tensorflow还提供了其它计算梯度的方法:

• tf.gradients(ys, xs, grad_ys=None, name='gradients', colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None)

该函数计算ys在xs方向上的梯度,需要注意与train.compute_gradients所不同的地方是,该函数返回一组dydx dydx的列表,而不是梯度-权值对。

其中,gate_gradients是在ys方向上的初始梯度,个人理解可以看做是偏微分链式求导中所需要的。

• tf.stop_gradient(input, name=None)

该函数告知整个graph图中,对input不进行梯度计算,将其伪装成一个constant常量。比如,可以用在类似于DQN算法中的目标函数:

cost=|r+Q next −Q current | cost=|r+Qnext−Qcurrent|

可以事先声明

y=tf.stop_gradient(r+Q next r+Qnext)

以上这篇TensorFlow的权值更新方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python网络编程学习笔记(10):webpy框架
Jun 09 Python
python 实时遍历日志文件
Apr 12 Python
Python3下错误AttributeError: ‘dict’ object has no attribute’iteritems‘的分析与解决
Jul 06 Python
详解python多线程、锁、event事件机制的简单使用
Apr 27 Python
python画柱状图--不同颜色并显示数值的方法
Dec 13 Python
Python实现的读取文件内容并写入其他文件操作示例
Apr 09 Python
Pandas之ReIndex重新索引的实现
Jun 25 Python
python中pip的使用和修改下载源的方法
Jul 08 Python
Django中使用极验Geetest滑动验证码过程解析
Jul 31 Python
使用python快速实现不同机器间文件夹共享方式
Dec 22 Python
解决python 执行sql语句时所传参数含有单引号的问题
Jun 06 Python
python爬取企查查企业信息之selenium自动模拟登录企查查
Apr 08 Python
python字符串常用方法
Jun 14 #Python
tensorflow 输出权重到csv或txt的实例
Jun 14 #Python
修复 Django migration 时遇到的问题解决
Jun 14 #Python
tensorflow 获取模型所有参数总和数量的方法
Jun 14 #Python
tensorflow 获取变量&打印权值的实例讲解
Jun 14 #Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
You might like
用PHP和ACCESS写聊天室(六)
2006/10/09 PHP
php简单封装了一些常用JS操作
2007/02/25 PHP
PHP 数组基础知识小结
2010/08/20 PHP
php获取英文姓名首字母的方法
2015/07/13 PHP
浅析Yii2 GridView 日期格式化并实现日期可搜索教程
2016/04/22 PHP
window.parent调用父框架时 ie跟火狐不兼容问题
2009/07/30 Javascript
jquery tools 系列 scrollable(2)
2009/09/06 Javascript
JavaScript 笔记二 Array和Date对象方法
2010/05/22 Javascript
关于jQuery中的end()使用方法
2011/07/10 Javascript
jquery制作属于自己的select自定义样式
2015/11/23 Javascript
laravel5.3 vue 实现收藏夹功能实例详解
2018/01/21 Javascript
ES6关于Promise的用法详解
2018/05/07 Javascript
JavaScript+Canvas实现彩色图片转换成黑白图片的方法分析
2018/07/31 Javascript
微信小程序自定义组件封装及父子间组件传值的方法
2018/08/28 Javascript
解决vue点击控制单个样式的问题
2018/09/05 Javascript
jQuery实现消息弹出框效果
2019/12/10 jQuery
ES6学习笔记之let与const用法实例分析
2020/01/22 Javascript
Jquery滑动门/tab切换实现方法完整示例
2020/06/05 jQuery
解决vue init webpack 下载依赖卡住不动的问题
2020/11/09 Javascript
用python实现批量重命名文件的代码
2012/05/25 Python
python机器学习实战之树回归详解
2017/12/20 Python
python 实现在无序数组中找到中位数方法
2020/03/03 Python
CSS3截取字符串实例代码【推荐】
2018/06/07 HTML / CSS
详解通过focusout事件解决IOS键盘收起时界面不归位的问题
2019/07/18 HTML / CSS
Perfume’s Club中文官网:西班牙美妆在线零售品牌
2020/08/24 全球购物
材料员岗位职责
2014/03/13 职场文书
美容院经理岗位职责
2014/04/03 职场文书
社会调查研究计划书
2014/05/01 职场文书
党员群众路线个人整改措施思想汇报
2014/10/12 职场文书
公民授权委托书
2014/10/15 职场文书
工作年限证明范本
2015/06/15 职场文书
我在伊朗长大观后感
2015/06/16 职场文书
《秋天的怀念》教学反思
2016/02/17 职场文书
vue.js Router中嵌套路由的实用示例
2021/06/27 Vue.js
不负正版帝国之名 《重返帝国》引领SLG手游制作新的标杆
2022/04/07 其他游戏
python使用shell脚本创建kafka连接器
2022/04/29 Python