TensorFlow的权值更新方法


Posted in Python onJune 14, 2018

一. MovingAverage权值滑动平均更新

1.1 示例代码:

def create_target_q_network(self,state_dim,action_dim,net):
  state_input = tf.placeholder("float",[None,state_dim])
  action_input = tf.placeholder("float",[None,action_dim])

  ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
  target_update = ema.apply(net)
  target_net = [ema.average(x) for x in net]

  layer1 = tf.nn.relu(tf.matmul(state_input,target_net[0]) + target_net[1])
  layer2 = tf.nn.relu(tf.matmul(layer1,target_net[2]) + tf.matmul(action_input,target_net[3]) + target_net[4])
  q_value_output = tf.identity(tf.matmul(layer2,target_net[5]) + target_net[6])

  return state_input,action_input,q_value_output,target_update

def update_target(self):
  self.sess.run(self.target_update)

其中,TAU=0.001,net是原始网络(该示例代码来自DDPG算法,经过滑动更新后的target_net是目标网络 )

第一句 tf.train.ExponentialMovingAverage,创建一个权值滑动平均的实例;

第二句 apply创建所训练模型参数的一个复制品(shadow_variable),并对这个复制品增加一个保留权值滑动平均的op,函数average()或average_name()可以用来获取最终这个复制品(平滑后)的值的。

更新公式为:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

在上述代码段中,target_net是shadow_variable,net是variable

1.2 tf.train.ExponentialMovingAverage.apply(var_list=None)

var_list必须是Variable或Tensor形式的列表。这个方法对var_list中所有元素创建一个复制,当其是Variable类型时,shadow_variable被初始化为variable的初值,当其是Tensor类型时,初始化为0,无偏。

函数返回一个进行权值平滑的op,因此更新目标网络时单独run这个函数就行。

1.3 tf.train.ExponentialMovingAverage.average(var)

用于获取var的滑动平均结果。

二. tf.train.Optimizer更新网络权值

2.1 tf.train.Optimizer

tf.train.Optimizer允许网络通过minimize()损失函数自动进行权值更新,此时tf.train.Optimizer.minimize()做了两件事:计算梯度,并把梯度自动更新到权值上。

此外,tensorflow也允许用户自己计算梯度,并做处理后应用给权值进行更新,此时分为以下三个步骤:

1.利用tf.train.Optimizer.compute_gradients计算梯度

2.对梯度进行自定义处理

3.利用tf.train.Optimizer.apply_gradients更新权值

tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None)

返回一个(梯度,权值)的列表对。

tf.train.Optimizer.apply_gradients(grads_and_vars, global_step=None, name=None)

返回一个更新权值的op,因此可以用它的返回值ret进行sess.run(ret)

2.2 其它

此外,tensorflow还提供了其它计算梯度的方法:

• tf.gradients(ys, xs, grad_ys=None, name='gradients', colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None)

该函数计算ys在xs方向上的梯度,需要注意与train.compute_gradients所不同的地方是,该函数返回一组dydx dydx的列表,而不是梯度-权值对。

其中,gate_gradients是在ys方向上的初始梯度,个人理解可以看做是偏微分链式求导中所需要的。

• tf.stop_gradient(input, name=None)

该函数告知整个graph图中,对input不进行梯度计算,将其伪装成一个constant常量。比如,可以用在类似于DQN算法中的目标函数:

cost=|r+Q next −Q current | cost=|r+Qnext−Qcurrent|

可以事先声明

y=tf.stop_gradient(r+Q next r+Qnext)

以上这篇TensorFlow的权值更新方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python爬虫入门教程之点点美女图片爬虫代码分享
Sep 02 Python
Python语言实现机器学习的K-近邻算法
Jun 11 Python
Python中operator模块的操作符使用示例总结
Jun 28 Python
python中星号变量的几种特殊用法
Sep 07 Python
Python实现的爬虫功能代码
Jun 24 Python
用Django写天气预报查询网站
Oct 21 Python
使用urllib库的urlretrieve()方法下载网络文件到本地的方法
Dec 19 Python
Python程序打包工具py2exe和PyInstaller详解
Jun 28 Python
python实现两个dict合并与计算操作示例
Jul 01 Python
Python发送邮件的实例代码讲解
Oct 16 Python
Python中的sys.stdout.write实现打印刷新功能
Feb 21 Python
python语言中pandas字符串分割str.split()函数
Aug 05 Python
python字符串常用方法
Jun 14 #Python
tensorflow 输出权重到csv或txt的实例
Jun 14 #Python
修复 Django migration 时遇到的问题解决
Jun 14 #Python
tensorflow 获取模型所有参数总和数量的方法
Jun 14 #Python
tensorflow 获取变量&打印权值的实例讲解
Jun 14 #Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
You might like
php 模拟POST|GET操作实现代码
2010/07/20 PHP
ThinkPHP中__initialize()和类的构造函数__construct()用法分析
2014/11/29 PHP
PHP如何使用Memcached
2016/04/05 PHP
php如何计算两坐标点之间的距离
2018/12/29 PHP
搭建PhpStorm+PhpStudy开发环境的超详细教程
2020/09/17 PHP
Prototype中dom对象方法汇总
2008/09/17 Javascript
JS将所有对象s的属性复制给对象r(原生js+jquery)
2014/01/25 Javascript
对于Form表单reset方法的新认识
2014/03/05 Javascript
Java Mybatis框架入门基础教程
2015/09/21 Javascript
搞定immutable.js详细说明
2016/05/02 Javascript
js省市县三级联动效果实例
2020/04/15 Javascript
jQuery自定义数值抽奖活动代码
2016/06/11 Javascript
基于bootstrap实现广告轮播带图片和文字效果
2016/07/22 Javascript
html5 canvas 详细使用教程
2017/01/20 Javascript
JS中showModalDialog关闭子窗口刷新主窗口用法详解
2017/03/25 Javascript
js字符限制(字符截取) 一个中文汉字算两个字符
2017/09/12 Javascript
python调用新浪微博API项目实践
2014/07/28 Python
解决python3 安装不了PIL的问题
2019/08/16 Python
基于Python获取城市近7天天气预报
2019/11/26 Python
python tqdm 实现滚动条不上下滚动代码(保持一行内滚动)
2020/02/19 Python
Python自动化之UnitTest框架实战记录
2020/09/08 Python
Flask-SocketIO服务端安装及使用代码示例
2020/11/26 Python
美国设计师精美珠宝购物网:Netaya
2016/08/28 全球购物
欧洲有机婴儿食品最大的市场:Organic Baby Food(供美国和加拿大)
2018/03/28 全球购物
new修饰符是起什么作用
2015/06/28 面试题
园林设计师自荐信
2013/11/18 职场文书
县优秀教师事迹材料
2014/01/31 职场文书
事业单位分类改革实施方案
2014/03/21 职场文书
企业读书活动总结
2014/06/30 职场文书
营销计划书
2015/01/17 职场文书
销售员岗位职责范本
2015/04/11 职场文书
2015年妇联工作总结范文
2015/04/22 职场文书
2015年生产车间工作总结
2015/04/22 职场文书
Python实战之OpenCV实现猫脸检测
2021/06/26 Python
HTML5 新增内容和 API详解
2021/11/17 HTML / CSS
Python各协议下socket黏包问题原理
2022/04/12 Python