TensorFlow实现Softmax回归模型


Posted in Python onMarch 09, 2018

一、概述及完整代码

对MNIST(MixedNational Institute of Standard and Technology database)这个非常简单的机器视觉数据集,Tensorflow为我们进行了方便的封装,可以直接加载MNIST数据成我们期望的格式.本程序使用Softmax Regression训练手写数字识别的分类模型.

先看完整代码:

import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
 
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
print(mnist.train.images.shape, mnist.train.labels.shape) 
print(mnist.test.images.shape, mnist.test.labels.shape) 
print(mnist.validation.images.shape, mnist.validation.labels.shape) 
 
#构建计算图 
x = tf.placeholder(tf.float32, [None, 784]) 
W = tf.Variable(tf.zeros([784, 10])) 
b = tf.Variable(tf.zeros([10])) 
y = tf.nn.softmax(tf.matmul(x, W) + b) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) 
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 
 
#在会话sess中启动图 
sess = tf.InteractiveSession() #创建InteractiveSession对象 
tf.global_variables_initializer().run() #全局参数初始化器 
for i in range(1000): 
 batch_xs, batch_ys = mnist.train.next_batch(100) 
 train_step.run({x: batch_xs, y_: batch_ys}) 
 
#测试验证阶段 
#沿着第1条轴方向取y和y_的最大值的索引并判断是否相等 
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 
#转换bool型tensor为float32型tensor并求平均即得到正确率 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

二、详细解读

首先看一下使用TensorFlow进行算法设计训练的核心步骤

1.定义算法公式,也就是神经网络forward时的计算;

2.定义loss,选定优化器,并制定优化器优化loss;

3.在训练集上迭代训练算法模型;

4.在测试集或验证集上对训练得到的模型进行准确率评测.

首先创建一个Placeholder,即输入张量数据的地方,第一个参数是数据类型dtype,第二个参数是tensor的形状shape.接下来创建SoftmaxRegression模型中的weights(W)和biases(b)的Variable对象,不同于存储数据的tensor一旦使用掉就会消失,Variable在模型训练迭代中是持久存在的,并且在每轮迭代中被更新Variable初始化可以是常量或随机值.接下来实现模型算法y = softmax(Wx + b),TensorFlow语言只需要一行代码,tf.nn包含了大量神经网络的组件,头tf.matmul是矩阵乘法函数.TensorFlow将模型中的forward和backward的内容都自动实现,只要定义好loss,训练的时候会自动求导并进行梯度下降,完成对模型参数的自动学习.定义损失函数lossfunction来描述分类精度,对于多分类问题通常使用cross-entropy交叉熵.先定义一个placeholder输入真实的label,tf.reduce_sum和tf.reduce_mean的功能分别是求和和求平均.构造完损失函数cross-entropy后,再定义一个优化算法即可开始训练.我们采用随机梯度下降SGD,定义好后TensorFlow会自动添加许多运算操作来实现反向传播和梯度下降,而给我们提供的是一个封装好的优化器,只需要每轮迭代时feed数据给它就好.设置好学习率.

构造阶段完成后, 才能启动图. 启动图的第一步是创建一个 Session 对象或InteractiveSession对象, 如果无任何创建参数, 会话构造器将启动默认图.创建InteractiveSession对象会这个Session注册为默认的Session,之后的运算也默认跑在这个Session里面,不同Session之间的数据和运算应该是相互独立的.下一步使用TensorFlow的全局参数初始化器tf.global_variables_initializer病直接执行它的run方法(这个全局参数初始化器应该是1.0.0版本中的新特性,在之前0.10.0版本测试不通过).

至此,以上定义的所有公式其实只是Computation Graph,代码执行到这时,计算还没有实际发生,只有等调用run方法并feed数据时计算才真正执行.

随后一步,就可以开始迭代地执行训练操作train_step.这里每次都从训练集中随机抽取100条样本构成一个mini-batch,并feed给placeholder.

完成迭代训练后,就可以对模型的准确率进行验证.比较y和y_在各个测试样本中最大值所在的索引,然后转换为float32型tensor后求平均即可得到正确率.多次测试后得到在测试集上的正确率为92%左右.还是比较理想的结果.

三、其他补充

1.Sesssion类和InteractiveSession类

对于product =tf.matmul(matrix1, matrix2),调用 sess 的 'run()' 方法来执行矩阵乘法 op, 传入 'product' 作为该方法的参数.上面提到, 'product' 代表了矩阵乘法 op 的输出, 传入它是向方法表明, 我们希望取回矩阵乘法 op 的输出.整个执行过程是自动化的, 会话负责传递op 所需的全部输入. op 通常是并发执行的.函数调用 'run(product)' 触发了图中三个 op (两个常量 op 和一个矩阵乘法 op)的执行.返回值 'result' 是一个 numpy的`ndarray`对象.

Session 对象在使用完后需要关闭以释放资源sess.close(). 除了显式调用 close 外, 也可以使用"with" 代码块 来自动完成关闭动作.

with tf.Session() as sess: 
 result = sess.run([product]) 
 print result

为了便于使用诸如 IPython 之类的 Python 交互环境, 可以使用InteractiveSession代替 Session 类, 使用 Tensor.eval()和 Operation.run()方法代替 Session.run(). 这样可以避免使用一个变量来持有会话.

# 进入一个交互式 TensorFlow 会话. 
import tensorflow as tf 
sess = tf.InteractiveSession() 
x = tf.Variable([1.0, 2.0]) 
a = tf.constant([3.0, 3.0]) 
# 使用初始化器 initializer op 的 run() 方法初始化 'x' 
x.initializer.run() 
# 增加一个减法 sub op, 从 'x' 减去 'a'. 运行减法 op, 输出结果 
sub = tf.sub(x, a) 
print sub.eval() 
# ==> [-2. -1.]

2.tf.reduce_sum

首先,tf.reduce_X一系列运算操作(operation)是实现对一个tensor各种减少维度的数学计算.

tf.reduce_sum(input_tensor, reduction_indices=None,keep_dims=False, name=None)

运算功能:沿着给定维度reduction_indices的方向降低input_tensor的维度,除非keep_dims=True,tensor的秩在reduction_indices上减1,被降低的维度的长度为1.如果reduction_indices没有传入参数,所有维度都降低,返回只含有1个元素的tensor.运算最终返回降维后的tensor.

演示代码:

# 'x' is [[1, 1, 1] 
#   [1, 1, 1]] 
tf.reduce_sum(x) ==> 6 
tf.reduce_sum(x, 0) ==> [2, 2, 2] 
tf.reduce_sum(x, 1) ==> [3, 3] 
tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]] 
tf.reduce_sum(x, [0, 1]) ==> 6

3.tf.reduce_mean

tf.reduce_mean(input_tensor, reduction_indices=None,keep_dims=False, name=None)

运算功能:将input_tensor沿着给定维度reduction_indices减少维度,除非keep_dims=True,tensor的秩在reduction_indices上减1,被降低的维度的长度为1.如果reduction_indices没有传入参数,所有维度都降低,返回只含有1个元素的tensor.运算最终返回降维后的tensor.

演示代码:

# 'x' is [[1., 1. ] 
#   [2., 2.]] 
tf.reduce_mean(x) ==> 1.5 
tf.reduce_mean(x, 0) ==> [1.5, 1.5] 
tf.reduce_mean(x, 1) ==> [1., 2.]

4.tf.argmax

tf.argmax(input, dimension, name=None)

运算功能:返回input在指定维度下的最大值的索引.返回类型为int64.

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 中文字符串的处理实现代码
Oct 25 Python
Python实现统计英文单词个数及字符串分割代码
May 28 Python
python+mongodb数据抓取详细介绍
Oct 25 Python
python中Apriori算法实现讲解
Dec 10 Python
python email smtplib模块发送邮件代码实例
Apr 26 Python
Ubuntu下Anaconda和Pycharm配置方法详解
Jun 14 Python
Python进程池Pool应用实例分析
Nov 27 Python
Python面向对象原理与基础语法详解
Jan 02 Python
Python matplotlib绘制图形实例(包括点,曲线,注释和箭头)
Apr 17 Python
pytorch查看模型weight与grad方式
Jun 24 Python
python的数学算法函数及公式用法
Nov 18 Python
OpenCV+Python3.5 简易手势识别的实现
Dec 21 Python
用python实现百度翻译的示例代码
Mar 09 #Python
TensorFlow深度学习之卷积神经网络CNN
Mar 09 #Python
TensorFlow实现卷积神经网络CNN
Mar 09 #Python
新手常见6种的python报错及解决方法
Mar 09 #Python
Python 函数基础知识汇总
Mar 09 #Python
Python 使用with上下文实现计时功能
Mar 09 #Python
TensorFlow搭建神经网络最佳实践
Mar 09 #Python
You might like
PHP5 面向对象(学习记录)
2009/12/02 PHP
php表单请求获得数据求和示例
2014/05/15 PHP
php数组函数array_walk用法示例
2016/05/26 PHP
ThinkPHP3.2框架使用addAll()批量插入数据的方法
2017/03/16 PHP
Laravel框架实现redis集群的方法分析
2017/09/14 PHP
IE下使用cloneNode注意事项分享
2012/11/22 Javascript
js导出格式化的excel 实例方法
2013/07/17 Javascript
javaScript arguments 对象使用介绍
2013/10/18 Javascript
jQuery学习笔记之toArray()
2014/06/09 Javascript
jQuery中nextAll()方法用法实例
2015/01/07 Javascript
javascript实现去除HTML标签的方法
2016/12/26 Javascript
JavaScript中localStorage对象存储方式实例分析
2017/01/12 Javascript
bootstrap常用组件之头部导航实现代码
2017/04/20 Javascript
在vue项目中引入高德地图及其UI组件的方法
2018/09/04 Javascript
python 实时遍历日志文件
2016/04/12 Python
Python中index()和seek()的用法(详解)
2017/04/27 Python
Python安装图文教程 Pycharm安装教程
2018/03/27 Python
Python异常处理操作实例详解
2018/05/10 Python
python实现停车管理系统
2018/11/30 Python
Python数据类型之Tuple元组实例详解
2019/05/08 Python
python中栈的原理及实现方法示例
2019/11/27 Python
什么是python的id函数
2020/06/11 Python
Python实现播放和录制声音的功能
2020/08/12 Python
移动端解决悬浮层(悬浮header、footer)会遮挡住内容的3种方法
2015/03/27 HTML / CSS
深入解析HTML5 Canvas控制图形矩阵变换的方法
2016/03/24 HTML / CSS
德国家具购物网站:Möbel Höffner
2019/08/26 全球购物
vue路由实现登录拦截
2021/03/24 Vue.js
零件设计自荐信范文
2013/11/27 职场文书
个人思想理论学习的自我鉴定
2013/11/30 职场文书
函授毕业自我鉴定
2014/02/04 职场文书
求职信怎么写范文
2014/05/26 职场文书
拾金不昧表扬信怎么写
2015/05/04 职场文书
消费者理赔投诉书
2015/07/02 职场文书
演讲稿之我的初心我的成长
2019/08/12 职场文书
浅谈哪个Python库才最适合做数据可视化
2021/06/28 Python
Mysql排序的特性详情
2021/11/01 MySQL