在Tensorflow中实现梯度下降法更新参数值


Posted in Python onJanuary 23, 2020

我就废话不多说了,直接上代码吧!

tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

TensorFlow经过使用梯度下降法对损失函数中的变量进行修改值,默认修改tf.Variable(tf.zeros([784,10]))

为Variable的参数。

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy,var_list=[w,b])

也可以使用var_list参数来定义更新那些参数的值

#导入Minst数据集
import input_data
mnist = input_data.read_data_sets("data",one_hot=True)
 
#导入tensorflow库
import tensorflow as tf
 
#输入变量,把28*28的图片变成一维数组(丢失结构信息)
x = tf.placeholder("float",[None,784])
 
#权重矩阵,把28*28=784的一维输入,变成0-9这10个数字的输出
w = tf.Variable(tf.zeros([784,10]))
#偏置
b = tf.Variable(tf.zeros([10]))
 
#核心运算,其实就是softmax(x*w+b)
y = tf.nn.softmax(tf.matmul(x,w) + b)
 
#这个是训练集的正确结果
y_ = tf.placeholder("float",[None,10])
 
#交叉熵,作为损失函数
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
 
#梯度下降算法,最小化交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
 
#初始化,在run之前必须进行的
init = tf.initialize_all_variables()
#创建session以便运算
sess = tf.Session()
sess.run(init)
 
#迭代1000次
for i in range(1000):
 #获取训练数据集的图片输入和正确表示数字
 batch_xs, batch_ys = mnist.train.next_batch(100)
 #运行刚才建立的梯度下降算法,x赋值为图片输入,y_赋值为正确的表示数字
 sess.run(train_step,feed_dict = {x:batch_xs, y_: batch_ys})
 
#tf.argmax获取最大值的索引。比较运算后的结果和本身结果是否相同。
#这步的结果应该是[1,1,1,1,1,1,1,1,0,1...........1,1,0,1]这种形式。
#1代表正确,0代表错误
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
 
#tf.cast先将数据转换成float,防止求平均不准确。
#tf.reduce_mean由于只有一个参数,就是上面那个数组的平均值。
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
#输出
print(sess.run(accuracy,feed_dict={x:mnist.test.images,y_: mnist.test.labels}))

计算结果如下

"C:\Program Files\Anaconda3\python.exe" D:/pycharmprogram/tensorflow_learn/softmax_learn/softmax_learn.py
Extracting data\train-images-idx3-ubyte.gz
Extracting data\train-labels-idx1-ubyte.gz
Extracting data\t10k-images-idx3-ubyte.gz
Extracting data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py:175: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
2018-05-14 15:49:45.866600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2018-05-14 15:49:45.866600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
0.9163
 
Process finished with exit code 0

如果限制,只更新参数W查看效果

"C:\Program Files\Anaconda3\python.exe" D:/pycharmprogram/tensorflow_learn/softmax_learn/softmax_learn.py
Extracting data\train-images-idx3-ubyte.gz
Extracting data\train-labels-idx1-ubyte.gz
Extracting data\t10k-images-idx3-ubyte.gz
Extracting data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py:175: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
2018-05-14 15:51:08.543600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2018-05-14 15:51:08.544600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
0.9187
 
Process finished with exit code 0

可以看出只修改W对结果影响不大,如果设置只修改b

#导入Minst数据集
import input_data
mnist = input_data.read_data_sets("data",one_hot=True)
 
#导入tensorflow库
import tensorflow as tf
 
#输入变量,把28*28的图片变成一维数组(丢失结构信息)
x = tf.placeholder("float",[None,784])
 
#权重矩阵,把28*28=784的一维输入,变成0-9这10个数字的输出
w = tf.Variable(tf.zeros([784,10]))
#偏置
b = tf.Variable(tf.zeros([10]))
 
#核心运算,其实就是softmax(x*w+b)
y = tf.nn.softmax(tf.matmul(x,w) + b)
 
#这个是训练集的正确结果
y_ = tf.placeholder("float",[None,10])
 
#交叉熵,作为损失函数
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
 
#梯度下降算法,最小化交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy,var_list=[b])
 
#初始化,在run之前必须进行的
init = tf.initialize_all_variables()
#创建session以便运算
sess = tf.Session()
sess.run(init)
 
#迭代1000次
for i in range(1000):
 #获取训练数据集的图片输入和正确表示数字
 batch_xs, batch_ys = mnist.train.next_batch(100)
 #运行刚才建立的梯度下降算法,x赋值为图片输入,y_赋值为正确的表示数字
 sess.run(train_step,feed_dict = {x:batch_xs, y_: batch_ys})
 
#tf.argmax获取最大值的索引。比较运算后的结果和本身结果是否相同。
#这步的结果应该是[1,1,1,1,1,1,1,1,0,1...........1,1,0,1]这种形式。
#1代表正确,0代表错误
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
 
#tf.cast先将数据转换成float,防止求平均不准确。
#tf.reduce_mean由于只有一个参数,就是上面那个数组的平均值。
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
#输出
print(sess.run(accuracy,feed_dict={x:mnist.test.images,y_: mnist.test.labels}))

计算结果:

"C:\Program Files\Anaconda3\python.exe" D:/pycharmprogram/tensorflow_learn/softmax_learn/softmax_learn.py
Extracting data\train-images-idx3-ubyte.gz
Extracting data\train-labels-idx1-ubyte.gz
Extracting data\t10k-images-idx3-ubyte.gz
Extracting data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py:175: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
2018-05-14 15:52:04.483600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2018-05-14 15:52:04.483600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
0.1135
 
Process finished with exit code 0

如果只更新b那么对效果影响很大。

以上这篇在Tensorflow中实现梯度下降法更新参数值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中List.index()方法的使用教程
May 20 Python
python简单实现计算过期时间的方法
Jun 09 Python
python中的lambda表达式用法详解
Jun 22 Python
Swift 3.0在集合类数据结构上的一些新变化总结
Jul 11 Python
Python3 Tkinter选择路径功能的实现方法
Jun 14 Python
python的一些加密方法及python 加密模块
Jul 11 Python
浅谈pycharm使用及设置方法
Sep 09 Python
Python 使用type来定义类的实现
Nov 19 Python
Python基于callable函数检测对象是否可被调用
Oct 16 Python
Python中openpyxl实现vlookup函数的实例
Oct 28 Python
如何基于Python爬虫爬取美团酒店信息
Nov 03 Python
python Pexpect模块的使用
Dec 25 Python
Tensorflow实现部分参数梯度更新操作
Jan 23 #Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 #Python
tensorflow 实现打印pb模型的所有节点
Jan 23 #Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 #Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 #Python
如何定义TensorFlow输入节点
Jan 23 #Python
You might like
overlord人气高涨,却被菲利普频繁举报,第四季很难在国内上映
2020/05/06 日漫
php获取地址栏信息的代码
2008/10/08 PHP
php中的PHP_EOL换行符详细解析
2013/10/26 PHP
分享php分页的功能模块
2015/06/16 PHP
两种php实现图片上传的方法
2016/01/22 PHP
Zend Framework教程之视图组件Zend_View用法详解
2016/03/05 PHP
PHP封装的数据库保存session功能类
2016/07/11 PHP
PHP微信红包生成代码分享
2016/10/06 PHP
js parentElement和offsetParent之间的区别
2010/03/23 Javascript
jQuery学习基础知识小结
2010/11/25 Javascript
input禁止键盘及中文输入,但可以点击
2014/02/13 Javascript
jQuery使用fadein方法实现渐出效果实例
2015/03/27 Javascript
JQuery boxy插件在IE中边角图片不显示问题的解决
2015/05/20 Javascript
详解JavaScript操作HTML DOM的基本方式
2015/10/21 Javascript
JS清除文本框内容离开在恢复及鼠标离开文本框时触发js的方法
2016/01/12 Javascript
bootstrap table表格插件使用详解
2017/05/08 Javascript
JS数组去重的6种方法完整实例
2018/12/08 Javascript
jQuery实现简单飞机大战
2020/07/05 jQuery
[03:09]DOTA2亚洲邀请赛 LGD战队出场宣传片
2015/02/07 DOTA
开源软件包和环境管理系统Anaconda的安装使用
2017/09/04 Python
详解用Python处理HTML转义字符的5种方式
2017/12/27 Python
Python实现统计英文文章词频的方法分析
2019/01/28 Python
详解Python 函数如何重载?
2019/04/23 Python
33个Python爬虫项目实战(推荐)
2019/07/08 Python
Python 整行读取文本方法并去掉readlines换行\n操作
2020/09/03 Python
解决python 执行shell命令无法获取返回值的问题
2020/12/05 Python
html5的canvas元素使用方法介绍(画矩形、画折线、圆形)
2014/04/14 HTML / CSS
英国网上花店:Bunches
2016/11/29 全球购物
英国银首饰公司:e&e Jewellery
2021/02/11 全球购物
SQL Server面试题
2016/10/17 面试题
土建资料员岗位职责
2014/01/04 职场文书
证婚人搞笑证婚词
2014/01/10 职场文书
静心口服夜广告词
2014/03/20 职场文书
企业总经理任命书
2014/06/05 职场文书
2015年元旦主持词开场白
2014/12/14 职场文书
公司环境卫生管理制度
2015/08/05 职场文书