TensorFlow实现Softmax回归模型


Posted in Python onMarch 09, 2018

一、概述及完整代码

对MNIST(MixedNational Institute of Standard and Technology database)这个非常简单的机器视觉数据集,Tensorflow为我们进行了方便的封装,可以直接加载MNIST数据成我们期望的格式.本程序使用Softmax Regression训练手写数字识别的分类模型.

先看完整代码:

import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
 
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
print(mnist.train.images.shape, mnist.train.labels.shape) 
print(mnist.test.images.shape, mnist.test.labels.shape) 
print(mnist.validation.images.shape, mnist.validation.labels.shape) 
 
#构建计算图 
x = tf.placeholder(tf.float32, [None, 784]) 
W = tf.Variable(tf.zeros([784, 10])) 
b = tf.Variable(tf.zeros([10])) 
y = tf.nn.softmax(tf.matmul(x, W) + b) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) 
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 
 
#在会话sess中启动图 
sess = tf.InteractiveSession() #创建InteractiveSession对象 
tf.global_variables_initializer().run() #全局参数初始化器 
for i in range(1000): 
 batch_xs, batch_ys = mnist.train.next_batch(100) 
 train_step.run({x: batch_xs, y_: batch_ys}) 
 
#测试验证阶段 
#沿着第1条轴方向取y和y_的最大值的索引并判断是否相等 
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 
#转换bool型tensor为float32型tensor并求平均即得到正确率 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

二、详细解读

首先看一下使用TensorFlow进行算法设计训练的核心步骤

1.定义算法公式,也就是神经网络forward时的计算;

2.定义loss,选定优化器,并制定优化器优化loss;

3.在训练集上迭代训练算法模型;

4.在测试集或验证集上对训练得到的模型进行准确率评测.

首先创建一个Placeholder,即输入张量数据的地方,第一个参数是数据类型dtype,第二个参数是tensor的形状shape.接下来创建SoftmaxRegression模型中的weights(W)和biases(b)的Variable对象,不同于存储数据的tensor一旦使用掉就会消失,Variable在模型训练迭代中是持久存在的,并且在每轮迭代中被更新Variable初始化可以是常量或随机值.接下来实现模型算法y = softmax(Wx + b),TensorFlow语言只需要一行代码,tf.nn包含了大量神经网络的组件,头tf.matmul是矩阵乘法函数.TensorFlow将模型中的forward和backward的内容都自动实现,只要定义好loss,训练的时候会自动求导并进行梯度下降,完成对模型参数的自动学习.定义损失函数lossfunction来描述分类精度,对于多分类问题通常使用cross-entropy交叉熵.先定义一个placeholder输入真实的label,tf.reduce_sum和tf.reduce_mean的功能分别是求和和求平均.构造完损失函数cross-entropy后,再定义一个优化算法即可开始训练.我们采用随机梯度下降SGD,定义好后TensorFlow会自动添加许多运算操作来实现反向传播和梯度下降,而给我们提供的是一个封装好的优化器,只需要每轮迭代时feed数据给它就好.设置好学习率.

构造阶段完成后, 才能启动图. 启动图的第一步是创建一个 Session 对象或InteractiveSession对象, 如果无任何创建参数, 会话构造器将启动默认图.创建InteractiveSession对象会这个Session注册为默认的Session,之后的运算也默认跑在这个Session里面,不同Session之间的数据和运算应该是相互独立的.下一步使用TensorFlow的全局参数初始化器tf.global_variables_initializer病直接执行它的run方法(这个全局参数初始化器应该是1.0.0版本中的新特性,在之前0.10.0版本测试不通过).

至此,以上定义的所有公式其实只是Computation Graph,代码执行到这时,计算还没有实际发生,只有等调用run方法并feed数据时计算才真正执行.

随后一步,就可以开始迭代地执行训练操作train_step.这里每次都从训练集中随机抽取100条样本构成一个mini-batch,并feed给placeholder.

完成迭代训练后,就可以对模型的准确率进行验证.比较y和y_在各个测试样本中最大值所在的索引,然后转换为float32型tensor后求平均即可得到正确率.多次测试后得到在测试集上的正确率为92%左右.还是比较理想的结果.

三、其他补充

1.Sesssion类和InteractiveSession类

对于product =tf.matmul(matrix1, matrix2),调用 sess 的 'run()' 方法来执行矩阵乘法 op, 传入 'product' 作为该方法的参数.上面提到, 'product' 代表了矩阵乘法 op 的输出, 传入它是向方法表明, 我们希望取回矩阵乘法 op 的输出.整个执行过程是自动化的, 会话负责传递op 所需的全部输入. op 通常是并发执行的.函数调用 'run(product)' 触发了图中三个 op (两个常量 op 和一个矩阵乘法 op)的执行.返回值 'result' 是一个 numpy的`ndarray`对象.

Session 对象在使用完后需要关闭以释放资源sess.close(). 除了显式调用 close 外, 也可以使用"with" 代码块 来自动完成关闭动作.

with tf.Session() as sess: 
 result = sess.run([product]) 
 print result

为了便于使用诸如 IPython 之类的 Python 交互环境, 可以使用InteractiveSession代替 Session 类, 使用 Tensor.eval()和 Operation.run()方法代替 Session.run(). 这样可以避免使用一个变量来持有会话.

# 进入一个交互式 TensorFlow 会话. 
import tensorflow as tf 
sess = tf.InteractiveSession() 
x = tf.Variable([1.0, 2.0]) 
a = tf.constant([3.0, 3.0]) 
# 使用初始化器 initializer op 的 run() 方法初始化 'x' 
x.initializer.run() 
# 增加一个减法 sub op, 从 'x' 减去 'a'. 运行减法 op, 输出结果 
sub = tf.sub(x, a) 
print sub.eval() 
# ==> [-2. -1.]

2.tf.reduce_sum

首先,tf.reduce_X一系列运算操作(operation)是实现对一个tensor各种减少维度的数学计算.

tf.reduce_sum(input_tensor, reduction_indices=None,keep_dims=False, name=None)

运算功能:沿着给定维度reduction_indices的方向降低input_tensor的维度,除非keep_dims=True,tensor的秩在reduction_indices上减1,被降低的维度的长度为1.如果reduction_indices没有传入参数,所有维度都降低,返回只含有1个元素的tensor.运算最终返回降维后的tensor.

演示代码:

# 'x' is [[1, 1, 1] 
#   [1, 1, 1]] 
tf.reduce_sum(x) ==> 6 
tf.reduce_sum(x, 0) ==> [2, 2, 2] 
tf.reduce_sum(x, 1) ==> [3, 3] 
tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]] 
tf.reduce_sum(x, [0, 1]) ==> 6

3.tf.reduce_mean

tf.reduce_mean(input_tensor, reduction_indices=None,keep_dims=False, name=None)

运算功能:将input_tensor沿着给定维度reduction_indices减少维度,除非keep_dims=True,tensor的秩在reduction_indices上减1,被降低的维度的长度为1.如果reduction_indices没有传入参数,所有维度都降低,返回只含有1个元素的tensor.运算最终返回降维后的tensor.

演示代码:

# 'x' is [[1., 1. ] 
#   [2., 2.]] 
tf.reduce_mean(x) ==> 1.5 
tf.reduce_mean(x, 0) ==> [1.5, 1.5] 
tf.reduce_mean(x, 1) ==> [1., 2.]

4.tf.argmax

tf.argmax(input, dimension, name=None)

运算功能:返回input在指定维度下的最大值的索引.返回类型为int64.

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的if、else、elif语句用法简明讲解
Mar 11 Python
python Socket之客户端和服务端握手详解
Sep 18 Python
手把手教你python实现SVM算法
Dec 27 Python
python自动结束mysql慢查询会话的实例代码
Oct 27 Python
详解Python list和numpy array的存储和读取方法
Nov 06 Python
pytorch .detach() .detach_() 和 .data用于切断反向传播的实现
Dec 27 Python
python能否java成为主流语言吗
Jun 22 Python
基于Python下载网络图片方法汇总代码实例
Jun 24 Python
django模型类中,null=True,blank=True用法说明
Jul 09 Python
利用Python实现自动扫雷小脚本
Dec 17 Python
详解使用python爬取抖音app视频(appium可以操控手机)
Jan 26 Python
python前后端自定义分页器
Apr 13 Python
用python实现百度翻译的示例代码
Mar 09 #Python
TensorFlow深度学习之卷积神经网络CNN
Mar 09 #Python
TensorFlow实现卷积神经网络CNN
Mar 09 #Python
新手常见6种的python报错及解决方法
Mar 09 #Python
Python 函数基础知识汇总
Mar 09 #Python
Python 使用with上下文实现计时功能
Mar 09 #Python
TensorFlow搭建神经网络最佳实践
Mar 09 #Python
You might like
ueditor 1.2.6 使用方法说明
2013/07/24 PHP
FormValid0.5版本发布,带ajax自定义验证例子
2007/08/17 Javascript
javascript面向对象之Javascript 继承
2010/05/04 Javascript
JavaScript自定义DateDiff函数(兼容所有浏览器)
2012/03/01 Javascript
拉动滚动条加载数据的jquery代码
2012/05/03 Javascript
可以用鼠标拖动的DIV实现思路及代码
2013/10/21 Javascript
javascript相等运算符与等同运算符详细介绍
2013/11/09 Javascript
使用JavaScript获取电池状态的方法
2014/05/03 Javascript
JavaScript在Android的WebView中parseInt函数转换不正确问题解决方法
2015/04/25 Javascript
jQuery实现图片与文字描述左右滑动自动切换的方法
2015/07/27 Javascript
cocos2dx骨骼动画Armature源码剖析(三)
2015/09/08 Javascript
浅谈node.js中async异步编程
2015/10/22 Javascript
jQuery使用$.ajax进行异步刷新的方法(附demo下载)
2015/12/04 Javascript
辨析JavaScript中的Undefined类型与null类型
2016/05/26 Javascript
js学习阶段总结(必看篇)
2016/06/16 Javascript
JS实现用户注册时获取短信验证码和倒计时功能
2016/10/27 Javascript
HTML5canvas 绘制一个圆环形的进度表示实例
2016/12/16 Javascript
从零开始做一个pagination分页组件
2017/03/15 Javascript
JavaScript实现多叉树的递归遍历和非递归遍历算法操作示例
2018/02/08 Javascript
详解Eslint 配置及规则说明
2018/09/10 Javascript
Django+Vue实现WebSocket连接的示例代码
2019/05/28 Javascript
JavaScript 格式化数字、金额、千分位、保留几位小数、舍入舍去
2019/07/23 Javascript
[00:32]2018DOTA2亚洲邀请赛VG出场
2018/04/03 DOTA
python实现按行切分文本文件的方法
2016/04/18 Python
Python 专题一 函数的基础知识
2017/03/16 Python
python中文件变化监控示例(watchdog)
2017/10/16 Python
Python+matplotlib+numpy实现在不同平面的二维条形图
2018/01/02 Python
Pandas聚合运算和分组运算的实现示例
2019/10/17 Python
Django --Xadmin 判断登录者身份实例
2020/07/03 Python
PyQt中使用QtSql连接MySql数据库的方法
2020/07/28 Python
Python基于内置函数type创建新类型
2020/10/22 Python
教师自我评价范文
2013/12/16 职场文书
学习党的群众路线剖析材料
2014/10/09 职场文书
士兵突击观后感
2015/06/16 职场文书
英语教学课后反思
2016/02/15 职场文书
Java使用Unsafe类的示例详解
2021/09/25 Java/Android