在Tensorflow中实现梯度下降法更新参数值


Posted in Python onJanuary 23, 2020

我就废话不多说了,直接上代码吧!

tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

TensorFlow经过使用梯度下降法对损失函数中的变量进行修改值,默认修改tf.Variable(tf.zeros([784,10]))

为Variable的参数。

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy,var_list=[w,b])

也可以使用var_list参数来定义更新那些参数的值

#导入Minst数据集
import input_data
mnist = input_data.read_data_sets("data",one_hot=True)
 
#导入tensorflow库
import tensorflow as tf
 
#输入变量,把28*28的图片变成一维数组(丢失结构信息)
x = tf.placeholder("float",[None,784])
 
#权重矩阵,把28*28=784的一维输入,变成0-9这10个数字的输出
w = tf.Variable(tf.zeros([784,10]))
#偏置
b = tf.Variable(tf.zeros([10]))
 
#核心运算,其实就是softmax(x*w+b)
y = tf.nn.softmax(tf.matmul(x,w) + b)
 
#这个是训练集的正确结果
y_ = tf.placeholder("float",[None,10])
 
#交叉熵,作为损失函数
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
 
#梯度下降算法,最小化交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
 
#初始化,在run之前必须进行的
init = tf.initialize_all_variables()
#创建session以便运算
sess = tf.Session()
sess.run(init)
 
#迭代1000次
for i in range(1000):
 #获取训练数据集的图片输入和正确表示数字
 batch_xs, batch_ys = mnist.train.next_batch(100)
 #运行刚才建立的梯度下降算法,x赋值为图片输入,y_赋值为正确的表示数字
 sess.run(train_step,feed_dict = {x:batch_xs, y_: batch_ys})
 
#tf.argmax获取最大值的索引。比较运算后的结果和本身结果是否相同。
#这步的结果应该是[1,1,1,1,1,1,1,1,0,1...........1,1,0,1]这种形式。
#1代表正确,0代表错误
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
 
#tf.cast先将数据转换成float,防止求平均不准确。
#tf.reduce_mean由于只有一个参数,就是上面那个数组的平均值。
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
#输出
print(sess.run(accuracy,feed_dict={x:mnist.test.images,y_: mnist.test.labels}))

计算结果如下

"C:\Program Files\Anaconda3\python.exe" D:/pycharmprogram/tensorflow_learn/softmax_learn/softmax_learn.py
Extracting data\train-images-idx3-ubyte.gz
Extracting data\train-labels-idx1-ubyte.gz
Extracting data\t10k-images-idx3-ubyte.gz
Extracting data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py:175: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
2018-05-14 15:49:45.866600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2018-05-14 15:49:45.866600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
0.9163
 
Process finished with exit code 0

如果限制,只更新参数W查看效果

"C:\Program Files\Anaconda3\python.exe" D:/pycharmprogram/tensorflow_learn/softmax_learn/softmax_learn.py
Extracting data\train-images-idx3-ubyte.gz
Extracting data\train-labels-idx1-ubyte.gz
Extracting data\t10k-images-idx3-ubyte.gz
Extracting data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py:175: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
2018-05-14 15:51:08.543600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2018-05-14 15:51:08.544600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
0.9187
 
Process finished with exit code 0

可以看出只修改W对结果影响不大,如果设置只修改b

#导入Minst数据集
import input_data
mnist = input_data.read_data_sets("data",one_hot=True)
 
#导入tensorflow库
import tensorflow as tf
 
#输入变量,把28*28的图片变成一维数组(丢失结构信息)
x = tf.placeholder("float",[None,784])
 
#权重矩阵,把28*28=784的一维输入,变成0-9这10个数字的输出
w = tf.Variable(tf.zeros([784,10]))
#偏置
b = tf.Variable(tf.zeros([10]))
 
#核心运算,其实就是softmax(x*w+b)
y = tf.nn.softmax(tf.matmul(x,w) + b)
 
#这个是训练集的正确结果
y_ = tf.placeholder("float",[None,10])
 
#交叉熵,作为损失函数
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
 
#梯度下降算法,最小化交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy,var_list=[b])
 
#初始化,在run之前必须进行的
init = tf.initialize_all_variables()
#创建session以便运算
sess = tf.Session()
sess.run(init)
 
#迭代1000次
for i in range(1000):
 #获取训练数据集的图片输入和正确表示数字
 batch_xs, batch_ys = mnist.train.next_batch(100)
 #运行刚才建立的梯度下降算法,x赋值为图片输入,y_赋值为正确的表示数字
 sess.run(train_step,feed_dict = {x:batch_xs, y_: batch_ys})
 
#tf.argmax获取最大值的索引。比较运算后的结果和本身结果是否相同。
#这步的结果应该是[1,1,1,1,1,1,1,1,0,1...........1,1,0,1]这种形式。
#1代表正确,0代表错误
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
 
#tf.cast先将数据转换成float,防止求平均不准确。
#tf.reduce_mean由于只有一个参数,就是上面那个数组的平均值。
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
#输出
print(sess.run(accuracy,feed_dict={x:mnist.test.images,y_: mnist.test.labels}))

计算结果:

"C:\Program Files\Anaconda3\python.exe" D:/pycharmprogram/tensorflow_learn/softmax_learn/softmax_learn.py
Extracting data\train-images-idx3-ubyte.gz
Extracting data\train-labels-idx1-ubyte.gz
Extracting data\t10k-images-idx3-ubyte.gz
Extracting data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py:175: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
2018-05-14 15:52:04.483600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2018-05-14 15:52:04.483600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
0.1135
 
Process finished with exit code 0

如果只更新b那么对效果影响很大。

以上这篇在Tensorflow中实现梯度下降法更新参数值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python远程登录代码
Apr 29 Python
python 正则式 概述及常用字符
May 07 Python
Python获取DLL和EXE文件版本号的方法
Mar 10 Python
疯狂上涨的Python 开发者应从2.x还是3.x着手?
Nov 16 Python
解决Python 爬虫URL中存在中文或特殊符号无法请求的问题
May 11 Python
python for 循环获取index索引的方法
Feb 01 Python
Python大数据之使用lxml库解析html网页文件示例
Nov 16 Python
基于python+selenium的二次封装的实现
Jan 06 Python
Python终端输出彩色字符方法详解
Feb 11 Python
Python实现像awk一样分割字符串
Sep 15 Python
浅析python 字典嵌套
Sep 29 Python
python分布式爬虫中消息队列知识点详解
Nov 26 Python
Tensorflow实现部分参数梯度更新操作
Jan 23 #Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 #Python
tensorflow 实现打印pb模型的所有节点
Jan 23 #Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 #Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 #Python
如何定义TensorFlow输入节点
Jan 23 #Python
You might like
PHP数据缓存技术
2007/02/14 PHP
php 获取完整url地址
2008/12/20 PHP
php生成的html meta和link标记在body标签里 顶部有个空行
2010/05/18 PHP
基于php常用正则表达式的整理汇总
2013/06/08 PHP
phpmyadmin配置文件现在需要绝密的短密码(blowfish_secret)的2种解决方法
2014/05/07 PHP
PHP后期静态绑定之self::限制实例分析
2018/12/21 PHP
零基础php编程好学吗
2019/10/11 PHP
PHPStorm 2020.1 调试 Nodejs的多种方法详解
2020/09/17 NodeJs
JQuery魔力之$("tagName")与selector
2012/03/05 Javascript
document.write()及其输出内容的样式、位置控制
2013/08/12 Javascript
javascript中typeof的使用示例
2013/12/19 Javascript
js操作iframe父子窗体示例
2014/05/22 Javascript
简介JavaScript中setUTCSeconds()方法的使用
2015/06/12 Javascript
jQuery使用模式窗口实现在主页面和子页面中互相传值的方法
2016/03/01 Javascript
Bootstrap源码学习笔记之bootstrap进度条
2016/12/24 Javascript
Easyui ueditor 整合解决不能编辑的问题(推荐)
2017/06/25 Javascript
Vue数据监听方法watch的使用
2018/03/28 Javascript
vue动画打包后失效问题的解决方法
2018/09/18 Javascript
angular中子控制器向父控制器传值的实例
2018/10/08 Javascript
layui 表单标签的校验方法
2019/09/04 Javascript
基于layui框架响应式布局的一些使用详解
2019/09/16 Javascript
vue 实现图片懒加载功能
2020/12/31 Vue.js
通过Python爬虫代理IP快速增加博客阅读量
2016/12/14 Python
网红编程语言Python将纳入高考你怎么看?
2018/06/07 Python
python 顺时针打印矩阵的超简洁代码
2018/11/14 Python
Django发送邮件和itsdangerous模块的配合使用解析
2019/08/10 Python
django echarts饼图数据动态加载的实例
2019/08/12 Python
python调用接口的4种方式代码实例
2019/11/19 Python
pytorch torch.nn.AdaptiveAvgPool2d()自适应平均池化函数详解
2020/01/03 Python
Python使用OpenPyXL处理Excel表格
2020/07/02 Python
趣天网日本站:Qoo10 JP
2019/09/18 全球购物
小学先进集体事迹材料
2014/05/31 职场文书
格列佛游记读书笔记
2015/06/30 职场文书
2016大学先进团支部事迹材料
2016/03/01 职场文书
2016年公共机构节能宣传周活动总结
2016/04/05 职场文书
Python可视化神器pyecharts之绘制箱形图
2022/07/07 Python