在Tensorflow中实现梯度下降法更新参数值


Posted in Python onJanuary 23, 2020

我就废话不多说了,直接上代码吧!

tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

TensorFlow经过使用梯度下降法对损失函数中的变量进行修改值,默认修改tf.Variable(tf.zeros([784,10]))

为Variable的参数。

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy,var_list=[w,b])

也可以使用var_list参数来定义更新那些参数的值

#导入Minst数据集
import input_data
mnist = input_data.read_data_sets("data",one_hot=True)
 
#导入tensorflow库
import tensorflow as tf
 
#输入变量,把28*28的图片变成一维数组(丢失结构信息)
x = tf.placeholder("float",[None,784])
 
#权重矩阵,把28*28=784的一维输入,变成0-9这10个数字的输出
w = tf.Variable(tf.zeros([784,10]))
#偏置
b = tf.Variable(tf.zeros([10]))
 
#核心运算,其实就是softmax(x*w+b)
y = tf.nn.softmax(tf.matmul(x,w) + b)
 
#这个是训练集的正确结果
y_ = tf.placeholder("float",[None,10])
 
#交叉熵,作为损失函数
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
 
#梯度下降算法,最小化交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
 
#初始化,在run之前必须进行的
init = tf.initialize_all_variables()
#创建session以便运算
sess = tf.Session()
sess.run(init)
 
#迭代1000次
for i in range(1000):
 #获取训练数据集的图片输入和正确表示数字
 batch_xs, batch_ys = mnist.train.next_batch(100)
 #运行刚才建立的梯度下降算法,x赋值为图片输入,y_赋值为正确的表示数字
 sess.run(train_step,feed_dict = {x:batch_xs, y_: batch_ys})
 
#tf.argmax获取最大值的索引。比较运算后的结果和本身结果是否相同。
#这步的结果应该是[1,1,1,1,1,1,1,1,0,1...........1,1,0,1]这种形式。
#1代表正确,0代表错误
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
 
#tf.cast先将数据转换成float,防止求平均不准确。
#tf.reduce_mean由于只有一个参数,就是上面那个数组的平均值。
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
#输出
print(sess.run(accuracy,feed_dict={x:mnist.test.images,y_: mnist.test.labels}))

计算结果如下

"C:\Program Files\Anaconda3\python.exe" D:/pycharmprogram/tensorflow_learn/softmax_learn/softmax_learn.py
Extracting data\train-images-idx3-ubyte.gz
Extracting data\train-labels-idx1-ubyte.gz
Extracting data\t10k-images-idx3-ubyte.gz
Extracting data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py:175: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
2018-05-14 15:49:45.866600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2018-05-14 15:49:45.866600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
0.9163
 
Process finished with exit code 0

如果限制,只更新参数W查看效果

"C:\Program Files\Anaconda3\python.exe" D:/pycharmprogram/tensorflow_learn/softmax_learn/softmax_learn.py
Extracting data\train-images-idx3-ubyte.gz
Extracting data\train-labels-idx1-ubyte.gz
Extracting data\t10k-images-idx3-ubyte.gz
Extracting data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py:175: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
2018-05-14 15:51:08.543600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2018-05-14 15:51:08.544600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
0.9187
 
Process finished with exit code 0

可以看出只修改W对结果影响不大,如果设置只修改b

#导入Minst数据集
import input_data
mnist = input_data.read_data_sets("data",one_hot=True)
 
#导入tensorflow库
import tensorflow as tf
 
#输入变量,把28*28的图片变成一维数组(丢失结构信息)
x = tf.placeholder("float",[None,784])
 
#权重矩阵,把28*28=784的一维输入,变成0-9这10个数字的输出
w = tf.Variable(tf.zeros([784,10]))
#偏置
b = tf.Variable(tf.zeros([10]))
 
#核心运算,其实就是softmax(x*w+b)
y = tf.nn.softmax(tf.matmul(x,w) + b)
 
#这个是训练集的正确结果
y_ = tf.placeholder("float",[None,10])
 
#交叉熵,作为损失函数
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
 
#梯度下降算法,最小化交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy,var_list=[b])
 
#初始化,在run之前必须进行的
init = tf.initialize_all_variables()
#创建session以便运算
sess = tf.Session()
sess.run(init)
 
#迭代1000次
for i in range(1000):
 #获取训练数据集的图片输入和正确表示数字
 batch_xs, batch_ys = mnist.train.next_batch(100)
 #运行刚才建立的梯度下降算法,x赋值为图片输入,y_赋值为正确的表示数字
 sess.run(train_step,feed_dict = {x:batch_xs, y_: batch_ys})
 
#tf.argmax获取最大值的索引。比较运算后的结果和本身结果是否相同。
#这步的结果应该是[1,1,1,1,1,1,1,1,0,1...........1,1,0,1]这种形式。
#1代表正确,0代表错误
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
 
#tf.cast先将数据转换成float,防止求平均不准确。
#tf.reduce_mean由于只有一个参数,就是上面那个数组的平均值。
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
#输出
print(sess.run(accuracy,feed_dict={x:mnist.test.images,y_: mnist.test.labels}))

计算结果:

"C:\Program Files\Anaconda3\python.exe" D:/pycharmprogram/tensorflow_learn/softmax_learn/softmax_learn.py
Extracting data\train-images-idx3-ubyte.gz
Extracting data\train-labels-idx1-ubyte.gz
Extracting data\t10k-images-idx3-ubyte.gz
Extracting data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py:175: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
2018-05-14 15:52:04.483600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2018-05-14 15:52:04.483600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
0.1135
 
Process finished with exit code 0

如果只更新b那么对效果影响很大。

以上这篇在Tensorflow中实现梯度下降法更新参数值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现2048小游戏
Mar 30 Python
Python实现的简单模板引擎功能示例
Sep 02 Python
详解Python nose单元测试框架的安装与使用
Dec 20 Python
Python网页正文转换语音文件的操作方法
Dec 09 Python
Flask-WTF表单的使用方法
Jul 12 Python
Pytest mark使用实例及原理解析
Feb 22 Python
python用opencv完成图像分割并进行目标物的提取
May 25 Python
浅谈keras使用中val_acc和acc值不同步的思考
Jun 18 Python
Python实现爬取并分析电商评论
Jun 19 Python
Python实现粒子群算法的示例
Feb 14 Python
python中scipy.stats产生随机数实例讲解
Feb 19 Python
python办公自动化之excel的操作
May 23 Python
Tensorflow实现部分参数梯度更新操作
Jan 23 #Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 #Python
tensorflow 实现打印pb模型的所有节点
Jan 23 #Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 #Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 #Python
如何定义TensorFlow输入节点
Jan 23 #Python
You might like
有关于PHP中常见数据类型的汇总分享
2014/01/06 PHP
PHP函数addslashes和mysql_real_escape_string的区别
2014/04/22 PHP
PHP获取时间排除周六、周日的两个方法
2014/06/30 PHP
php冒泡排序与快速排序实例详解
2015/12/07 PHP
PHP+Ajax验证码验证用户登录
2016/07/20 PHP
PHP jpgraph库的配置及生成统计图表:折线图、柱状图、饼状图
2017/05/15 PHP
php+js实现的拖动滑块验证码验证表单操作示例【附源码下载】
2020/05/27 PHP
javascript下过滤数组重复值的代码
2007/09/10 Javascript
Javascript 面试题随笔
2011/03/31 Javascript
浅析js封装和作用域
2013/07/09 Javascript
使用node.js半年来总结的 10 条经验
2014/08/18 Javascript
Javascript学习笔记之函数篇(四):arguments 对象
2014/11/23 Javascript
vue双向数据绑定原理探究(附demo)
2017/01/17 Javascript
angularJS深拷贝详解
2017/03/23 Javascript
HTML的select控件美化
2017/03/27 Javascript
JavaScript 中调用 Kotlin 方法实例详解
2017/06/09 Javascript
Angular2+如何去除url中的#号详解
2017/12/20 Javascript
JS通过位运算实现权限加解密
2018/08/14 Javascript
深入浅析Vue 中 ref 的使用
2019/04/29 Javascript
vue动态注册组件实例代码详解
2019/05/30 Javascript
javascript function(函数类型)使用与注意事项小结
2019/06/10 Javascript
vue实现图片上传预览功能
2019/12/23 Javascript
Python实现向QQ群成员自动发邮件的方法
2014/11/19 Python
Python运算符重载用法实例
2015/05/28 Python
python制作爬虫并将抓取结果保存到excel中
2016/04/06 Python
Python使用剪切板的方法
2017/06/06 Python
在OpenCV里实现条码区域识别的方法示例
2019/12/04 Python
python 实现让字典的value 成为列表
2019/12/16 Python
python中的错误如何查看
2020/07/08 Python
欧洲最大的化妆品连锁公司:Douglas道格拉斯
2017/05/06 全球购物
大四毕业生学习总结的自我评价
2013/10/31 职场文书
成人教育自我鉴定
2013/11/01 职场文书
老师对学生的寄语
2014/04/09 职场文书
2016年寒假见闻
2015/10/10 职场文书
2019员工保密协议书(3篇)
2019/09/23 职场文书
html输入两个数实现加减乘除功能
2021/07/01 HTML / CSS