TensorFlow实现Softmax回归模型


Posted in Python onMarch 09, 2018

一、概述及完整代码

对MNIST(MixedNational Institute of Standard and Technology database)这个非常简单的机器视觉数据集,Tensorflow为我们进行了方便的封装,可以直接加载MNIST数据成我们期望的格式.本程序使用Softmax Regression训练手写数字识别的分类模型.

先看完整代码:

import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
 
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
print(mnist.train.images.shape, mnist.train.labels.shape) 
print(mnist.test.images.shape, mnist.test.labels.shape) 
print(mnist.validation.images.shape, mnist.validation.labels.shape) 
 
#构建计算图 
x = tf.placeholder(tf.float32, [None, 784]) 
W = tf.Variable(tf.zeros([784, 10])) 
b = tf.Variable(tf.zeros([10])) 
y = tf.nn.softmax(tf.matmul(x, W) + b) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) 
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 
 
#在会话sess中启动图 
sess = tf.InteractiveSession() #创建InteractiveSession对象 
tf.global_variables_initializer().run() #全局参数初始化器 
for i in range(1000): 
 batch_xs, batch_ys = mnist.train.next_batch(100) 
 train_step.run({x: batch_xs, y_: batch_ys}) 
 
#测试验证阶段 
#沿着第1条轴方向取y和y_的最大值的索引并判断是否相等 
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 
#转换bool型tensor为float32型tensor并求平均即得到正确率 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

二、详细解读

首先看一下使用TensorFlow进行算法设计训练的核心步骤

1.定义算法公式,也就是神经网络forward时的计算;

2.定义loss,选定优化器,并制定优化器优化loss;

3.在训练集上迭代训练算法模型;

4.在测试集或验证集上对训练得到的模型进行准确率评测.

首先创建一个Placeholder,即输入张量数据的地方,第一个参数是数据类型dtype,第二个参数是tensor的形状shape.接下来创建SoftmaxRegression模型中的weights(W)和biases(b)的Variable对象,不同于存储数据的tensor一旦使用掉就会消失,Variable在模型训练迭代中是持久存在的,并且在每轮迭代中被更新Variable初始化可以是常量或随机值.接下来实现模型算法y = softmax(Wx + b),TensorFlow语言只需要一行代码,tf.nn包含了大量神经网络的组件,头tf.matmul是矩阵乘法函数.TensorFlow将模型中的forward和backward的内容都自动实现,只要定义好loss,训练的时候会自动求导并进行梯度下降,完成对模型参数的自动学习.定义损失函数lossfunction来描述分类精度,对于多分类问题通常使用cross-entropy交叉熵.先定义一个placeholder输入真实的label,tf.reduce_sum和tf.reduce_mean的功能分别是求和和求平均.构造完损失函数cross-entropy后,再定义一个优化算法即可开始训练.我们采用随机梯度下降SGD,定义好后TensorFlow会自动添加许多运算操作来实现反向传播和梯度下降,而给我们提供的是一个封装好的优化器,只需要每轮迭代时feed数据给它就好.设置好学习率.

构造阶段完成后, 才能启动图. 启动图的第一步是创建一个 Session 对象或InteractiveSession对象, 如果无任何创建参数, 会话构造器将启动默认图.创建InteractiveSession对象会这个Session注册为默认的Session,之后的运算也默认跑在这个Session里面,不同Session之间的数据和运算应该是相互独立的.下一步使用TensorFlow的全局参数初始化器tf.global_variables_initializer病直接执行它的run方法(这个全局参数初始化器应该是1.0.0版本中的新特性,在之前0.10.0版本测试不通过).

至此,以上定义的所有公式其实只是Computation Graph,代码执行到这时,计算还没有实际发生,只有等调用run方法并feed数据时计算才真正执行.

随后一步,就可以开始迭代地执行训练操作train_step.这里每次都从训练集中随机抽取100条样本构成一个mini-batch,并feed给placeholder.

完成迭代训练后,就可以对模型的准确率进行验证.比较y和y_在各个测试样本中最大值所在的索引,然后转换为float32型tensor后求平均即可得到正确率.多次测试后得到在测试集上的正确率为92%左右.还是比较理想的结果.

三、其他补充

1.Sesssion类和InteractiveSession类

对于product =tf.matmul(matrix1, matrix2),调用 sess 的 'run()' 方法来执行矩阵乘法 op, 传入 'product' 作为该方法的参数.上面提到, 'product' 代表了矩阵乘法 op 的输出, 传入它是向方法表明, 我们希望取回矩阵乘法 op 的输出.整个执行过程是自动化的, 会话负责传递op 所需的全部输入. op 通常是并发执行的.函数调用 'run(product)' 触发了图中三个 op (两个常量 op 和一个矩阵乘法 op)的执行.返回值 'result' 是一个 numpy的`ndarray`对象.

Session 对象在使用完后需要关闭以释放资源sess.close(). 除了显式调用 close 外, 也可以使用"with" 代码块 来自动完成关闭动作.

with tf.Session() as sess: 
 result = sess.run([product]) 
 print result

为了便于使用诸如 IPython 之类的 Python 交互环境, 可以使用InteractiveSession代替 Session 类, 使用 Tensor.eval()和 Operation.run()方法代替 Session.run(). 这样可以避免使用一个变量来持有会话.

# 进入一个交互式 TensorFlow 会话. 
import tensorflow as tf 
sess = tf.InteractiveSession() 
x = tf.Variable([1.0, 2.0]) 
a = tf.constant([3.0, 3.0]) 
# 使用初始化器 initializer op 的 run() 方法初始化 'x' 
x.initializer.run() 
# 增加一个减法 sub op, 从 'x' 减去 'a'. 运行减法 op, 输出结果 
sub = tf.sub(x, a) 
print sub.eval() 
# ==> [-2. -1.]

2.tf.reduce_sum

首先,tf.reduce_X一系列运算操作(operation)是实现对一个tensor各种减少维度的数学计算.

tf.reduce_sum(input_tensor, reduction_indices=None,keep_dims=False, name=None)

运算功能:沿着给定维度reduction_indices的方向降低input_tensor的维度,除非keep_dims=True,tensor的秩在reduction_indices上减1,被降低的维度的长度为1.如果reduction_indices没有传入参数,所有维度都降低,返回只含有1个元素的tensor.运算最终返回降维后的tensor.

演示代码:

# 'x' is [[1, 1, 1] 
#   [1, 1, 1]] 
tf.reduce_sum(x) ==> 6 
tf.reduce_sum(x, 0) ==> [2, 2, 2] 
tf.reduce_sum(x, 1) ==> [3, 3] 
tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]] 
tf.reduce_sum(x, [0, 1]) ==> 6

3.tf.reduce_mean

tf.reduce_mean(input_tensor, reduction_indices=None,keep_dims=False, name=None)

运算功能:将input_tensor沿着给定维度reduction_indices减少维度,除非keep_dims=True,tensor的秩在reduction_indices上减1,被降低的维度的长度为1.如果reduction_indices没有传入参数,所有维度都降低,返回只含有1个元素的tensor.运算最终返回降维后的tensor.

演示代码:

# 'x' is [[1., 1. ] 
#   [2., 2.]] 
tf.reduce_mean(x) ==> 1.5 
tf.reduce_mean(x, 0) ==> [1.5, 1.5] 
tf.reduce_mean(x, 1) ==> [1., 2.]

4.tf.argmax

tf.argmax(input, dimension, name=None)

运算功能:返回input在指定维度下的最大值的索引.返回类型为int64.

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中用memcached来减少数据库查询次数的教程
Apr 07 Python
浅析Python多线程下的变量问题
Apr 28 Python
wxpython中Textctrl回车事件无效的解决方法
Jul 21 Python
Python实现公历(阳历)转农历(阴历)的方法示例
Aug 22 Python
Python实现字符串匹配算法代码示例
Dec 05 Python
TensorFlow如何实现反向传播
Feb 06 Python
numpy使用fromstring创建矩阵的实例
Jun 15 Python
对pandas写入读取h5文件的方法详解
Dec 28 Python
Python Pillow Image Invert
Jan 22 Python
PYTHON实现SIGN签名的过程解析
Oct 28 Python
Python PyPDF2模块安装使用解析
Jan 19 Python
python 获取谷歌浏览器保存的密码
Jan 06 Python
用python实现百度翻译的示例代码
Mar 09 #Python
TensorFlow深度学习之卷积神经网络CNN
Mar 09 #Python
TensorFlow实现卷积神经网络CNN
Mar 09 #Python
新手常见6种的python报错及解决方法
Mar 09 #Python
Python 函数基础知识汇总
Mar 09 #Python
Python 使用with上下文实现计时功能
Mar 09 #Python
TensorFlow搭建神经网络最佳实践
Mar 09 #Python
You might like
php代码优化及php相关问题总结
2006/10/09 PHP
PHP 编程的 5个良好习惯
2009/02/20 PHP
php安全之直接用$获取值而不$_GET 字符转义
2012/06/03 PHP
深入php之规范编程命名小结
2013/05/15 PHP
探讨如何在php168_cms中提取验证码
2013/06/08 PHP
解决Yii2邮件发送结果返回成功,但接收不到邮件的问题
2017/05/23 PHP
jQuery学习2 选择器的使用说明
2010/02/07 Javascript
javascript针对DOM的应用分析(三)
2012/04/15 Javascript
用js判断页面刷新或关闭的方法(onbeforeunload与onunload事件)
2012/06/22 Javascript
jQuery获得内容和属性方法及示例
2013/12/02 Javascript
javaScript给元素添加多个class的简单实现
2016/07/20 Javascript
酷! 不同风格页面布局幻灯片特效js实现
2021/02/19 Javascript
js控住DOM实现发布微博效果
2016/08/30 Javascript
如何编写jquery插件
2017/03/29 jQuery
JS实现二叉查找树的建立以及一些遍历方法实现
2017/04/17 Javascript
Vue.js 实现微信公众号菜单编辑器功能(一)
2018/05/08 Javascript
nodejs 使用nodejs-websocket模块实现点对点实时通讯
2018/11/28 NodeJs
微信小程序工具函数封装
2019/10/28 Javascript
[01:20]辉夜杯背景故事宣传片《辉夜传说》
2015/12/25 DOTA
[01:04]DOTA2上海特锦赛现场采访 FreeAgain遭众解说围攻
2016/03/25 DOTA
python client使用http post 到server端的代码
2013/02/10 Python
Python SQLAlchemy基本操作和常用技巧(包含大量实例,非常好)
2014/05/06 Python
pycharm 使用心得(九)解决No Python interpreter selected的问题
2014/06/06 Python
跟老齐学Python之玩转字符串(3)
2014/09/14 Python
Python获取当前函数名称方法实例分享
2018/01/18 Python
[原创]Python入门教程2. 字符串基本操作【运算、格式化输出、常用函数】
2018/10/29 Python
Python中使用pypdf2合并、分割、加密pdf文件的代码详解
2019/05/21 Python
python 利用turtle模块画出没有角的方格
2019/11/23 Python
Pytorch Tensor 输出为txt和mat格式方式
2020/01/03 Python
Selenium常见异常解析及解决方案示范
2020/04/10 Python
CSS3悬停效果案例应用
2012/11/21 HTML / CSS
HTML5 File接口在web页面上使用文件下载
2017/02/27 HTML / CSS
VLAN和VPN有什么区别?分别实现在OSI的第几层?
2014/12/23 面试题
酒店秘书求职信范文
2014/02/17 职场文书
2015年幼儿园中班开学寄语
2015/05/27 职场文书
Python数据可视化之用Matplotlib绘制常用图形
2021/06/03 Python