Python利用逻辑回归模型解决MNIST手写数字识别问题详解


Posted in Python onJanuary 14, 2020

本文实例讲述了Python利用逻辑回归模型解决MNIST手写数字识别问题。分享给大家供大家参考,具体如下:

1、MNIST手写识别问题

MNIST手写数字识别问题:输入黑白的手写阿拉伯数字,通过机器学习判断输入的是几。可以通过TensorFLow下载MNIST手写数据集,通过import引入MNIST数据集并进行读取,会自动从网上下载所需文件。

%matplotlib inline
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist=input_data.read_data_sets('MNIST_data/',one_hot=True)
import matplotlib.pyplot as plt
 
def plot_image(image):                #图片显示函数
  plt.imshow(image.reshape(28,28),cmap='binary')
  plt.show()
 
print("训练集数量:",mnist.train.num_examples,
   "特征值组成:",mnist.train.images.shape,
   "标签组成:",mnist.train.labels.shape)
 
batch_images,batch_labels=mnist.train.next_batch(batch_size=10)  #批量读取数据
print(batch_images.shape,batch_labels.shape)
 
print('标签值:',np.argmax(mnist.train.labels[1000]),end=' ')  #np.argmax()得到实际值
print('独热编码表示:',mnist.train.labels[1000])
plot_image(mnist.train.images[1000])         #显示数据集中第1000张图片

Python利用逻辑回归模型解决MNIST手写数字识别问题详解Python利用逻辑回归模型解决MNIST手写数字识别问题详解

输出训练集 的数量有55000个,并打印特征值的shape为(55000,784),其中784代表每张图片由28*28个像素点组成,由于是黑白图片,每个像素点只有黑白单通道,即通过784个数可以描述一张图片的特征值。可以将图片在Jupyter中输出,将784个特征值reshape为28×28的二维数组,传给plt.imshow()函数,之后再通过show()输出。

MNIST提供next_batch()方法用于批量读取数据集,例如上面批量读取10个对应的images与labels数据并分别返回。该方法会按顺序一直往后读取,直到结束后会自动打乱数据,重新继续读取。

在打开mnist数据集时,第二个参数设置one_hot,表示采用独热编码方式打开。独热编码是一种稀疏向量,其中一个元素为1,其他元素均为0,常用于表示有限个可能的组合情况。例如数字6的独热编码为第7个分量为1,其他为0的数组。可以通过np.argmax()函数返回数组最大值的下标,即独热编码表示的实际数字。通过独热编码可以将离散特征的某个取值对应欧氏空间的某个点,有利于机器学习中特征之间的距离计算

数据集的划分,一种划分为训练集用于模型的训练,测试集用于结果的测试,要求集合数量足够大,而且具有代表性。但是在多次执行后,会导致模型向测试集数据进行拟合,从而导致测试集数据失去了测试的效果。因此将数据集进一步划分为训练集、验证集、测试集,将训练后的模型用验证集验证,当多次迭代结束之后再拿测试集去测试。MNIST数据集中的训练集为mnist.train,验证集为mnist.validation,测试集为mnist.test

2、逻辑回归

与线性回归相对比,房价预测是根据多个输入参数x与对应权重w相乘再加上b得到线性的输出房价。而还有许多问题的输出是非线性的、控制在[0,1]之间的,比如判断邮件是否为垃圾邮件,手写数字为0~9等,逻辑回归就是用于处理此类问题。例如电子邮件分类器输出0.8,表示该邮件为垃圾邮件的概率是0.8.

逻辑回归通过Sigmoid函数保证输出的值在[0,1]之间,该函数可以将全体实数映射到[0,1],从而将线性的输出转换为[0,1]的数。其定义与图像如下:

Python利用逻辑回归模型解决MNIST手写数字识别问题详解Python利用逻辑回归模型解决MNIST手写数字识别问题详解

在逻辑回归中如果采用均方差的损失函数,带入sigmoid会得到一个非凸函数,这类函数会有多个极小值,采用梯度下降法便无法求得最优解。因此在逻辑回归中采用对数损失函数Python利用逻辑回归模型解决MNIST手写数字识别问题详解,其中y是特征值x的标签,y'是预测值。

在手写数字识别中,通过单层神经元产生连续的输出值y,将y再输入到softmax层处理,经过函数计算将结果映射为0~9每个数字对应的概率,概率越大表示该图片越像某个数字,所有数字的概率之和为1

Python利用逻辑回归模型解决MNIST手写数字识别问题详解

交叉熵损失函数:交叉熵用于刻画两个概率分布之间的距离Python利用逻辑回归模型解决MNIST手写数字识别问题详解,其中p代表正确答案,q代表预测值,交叉熵越小距离越近,从而模型的预测越准确。例如正确答案为(1,0,0),甲模型预测为(0.5,0.2,0.3),其交叉熵=-1*log0.5≈0.3,乙模型(0.7,0.1,0.2),其交叉熵=-1*log0.7≈0.15,所以乙模型预测更准确

模型的训练

首先定义二维浮点数占位符x、y,以及二维参数变量W、b并随机赋初值。之后定义前向计算为向量x与W对应叉乘再加b,并将得到的线性结果经过softmax处理得到独热编码预测值。

之后定义准确率accuracy,其值为预测值pred与真实值y相等个数来衡量

接下来初始化变量、设置超参数,并定义损失函数、优化器,之后开始训练。每轮训练中分批次读取数据进行训练,每轮训练结束后输出损失与准确率。

import numpy as np
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist=input_data.read_data_sets('MNIST_data/',one_hot=True)
import matplotlib.pyplot as plt
 
#定义占位符、变量、前向计算
x=tf.placeholder(tf.float32,[None,784],name='x')
y=tf.placeholder(tf.float32,[None,10],name='y')
W=tf.Variable(tf.random_normal([784,10]),name='W')
b=tf.Variable(tf.zeros([10]),name='b')
forward=tf.matmul(x,W)+b
pred=tf.nn.softmax(forward)               #通过softmax将线性结果分类处理
 
#计算预测值与真实值的匹配个数
correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
#将上一步得到的布尔值转换为浮点数,并求平均值,得到准确率
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
 
ss=tf.Session()
init=tf.global_variables_initializer()
ss.run(init)
 
#超参数设置
train_epochs=50
batch_size=100                        #每个批次的样本数
batch_num=int(mnist.train.num_examples/batch_size)      #一轮需要训练多少批
learning_rate=0.01
 
#定义交叉熵损失函数、梯度下降优化器
loss_function=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) 
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
 
for epoch in range(train_epochs):
  for batch in range(batch_num):              #分批次读取数据进行训练
    xs,ys=mnist.train.next_batch(batch_size)
    ss.run(optimizer,feed_dict={x:xs,y:ys})
  #每轮训练结束后通过带入验证集的数据,检测模型的损失与准去率 
  loss,acc=ss.run([loss_function,accuracy],\
          feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
  print('第%2d轮训练:损失为:%9f,准确率:%.4f'%(epoch+1,loss,acc))

从每轮训练结果可以看出损失在逐渐下降,准确率在逐步上升。

Python利用逻辑回归模型解决MNIST手写数字识别问题详解

结果预测

使用训练好的模型对测试集中的数据进行预测,即将mnist.test.images数据带入去求pred的值。

为了使结果更便于显示,可以借助plot函数库将图片数据显示出来,并配以文字label与predic的值。首先通过plt.gcf()得到一副图像资源并设置其大小。再通过plt.subplot(5,5,index+1)函数将其划分为5×5个子图,遍历第index+1个子图,分别将图像资源绘制到子图,通过set_title()设置每个子图的title显示内容。子图绘制结束后显示整个图片,并调用函数传入图片、标签、预测值等参数。

prediction=ss.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images})
 
def show_result(images,labels,prediction,index,num=10):   #绘制图形显示预测结果
  pic=plt.gcf()                      #获取当前图像
  pic.set_size_inches(10,12)               #设置图片大小
  for i in range(0,num):
    sub_pic=plt.subplot(5,5,i+1)            #获取第i个子图
    #将第index个images信息显示到子图上
    sub_pic.imshow(np.reshape(images[index],(28,28)),cmap='binary') 
    title="label:"+str(np.argmax(labels[index]))    #设置子图的title内容
    if len(prediction)>0:
      title+=",predict:"+str(prediction[index])
      
    sub_pic.set_title(title,fontsize=10)
    sub_pic.set_xticks([])               #设置x、y坐标轴不显示
    sub_pic.set_yticks([])
    index+=1
  plt.show()
show_result(mnist.test.images,mnist.test.labels,prediction,10)

运行结果如下,可以看到预测的结果大多准确

Python利用逻辑回归模型解决MNIST手写数字识别问题详解

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python定时检查启动某个exe程序适合检测exe是否挂了
Jan 21 Python
Python sys.path详细介绍
Oct 17 Python
浅析使用Python操作文件
Jul 31 Python
Python实现连接两个无规则列表后删除重复元素并升序排序的方法
Feb 05 Python
Python实现决策树C4.5算法的示例
May 30 Python
python实现自动登录
Sep 17 Python
简单了解Django应用app及分布式路由
Jul 24 Python
python实现大量图片重命名
Mar 23 Python
python使用python-pptx删除ppt某页实例
Feb 14 Python
tensorflow常用函数API介绍
Apr 19 Python
Django 解决新建表删除后无法重新创建等问题
May 21 Python
解决pytorch 交叉熵损失输出为负数的问题
Jul 07 Python
np.random.seed() 的使用详解
Jan 14 #Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
Jan 14 #Python
Python selenium 自动化脚本打包成一个exe文件(推荐)
Jan 14 #Python
pytorch+lstm实现的pos示例
Jan 14 #Python
Python中sorted()排序与字母大小写的问题
Jan 14 #Python
Pytorch实现LSTM和GRU示例
Jan 14 #Python
Python生成词云的实现代码
Jan 14 #Python
You might like
PHP分页显示制作详细讲解
2006/12/05 PHP
PHP中func_get_args(),func_get_arg(),func_num_args()的区别
2013/09/30 PHP
php中eval函数的危害与正确禁用方法
2014/06/30 PHP
PHP中IP地址与整型数字互相转换详解
2014/08/20 PHP
怎么用javascript进行拖拽
2006/07/20 Javascript
ArrayList类(增强版)
2007/04/04 Javascript
JavaScript 语言的递归编程
2010/05/18 Javascript
返回上一页并自动刷新的JavaScript代码
2014/02/19 Javascript
JQuery EasyUI 数字格式化处理示例
2014/05/05 Javascript
使用requestAnimationFrame实现js动画性能好
2015/08/06 Javascript
js支持键盘控制的左右切换立体式图片轮播效果代码分享
2015/08/26 Javascript
AngularJS中的$watch(),$digest()和$apply()区分
2016/04/04 Javascript
JavaScript实现大图轮播效果
2017/01/11 Javascript
Vue项目中quill-editor带样式编辑器的使用方法
2017/08/08 Javascript
Vue实现typeahead组件功能(非常靠谱)
2017/08/26 Javascript
jQuery实现数字自动增加或者减少的动画效果示例
2018/12/11 jQuery
Vue 利用指令实现禁止反复发送请求的两种方法
2019/09/15 Javascript
vue与django集成打包的实现方法
2019/11/11 Javascript
JavaScript 监听组合按键思路及代码实现
2020/07/28 Javascript
Python去除、替换字符串空格的处理方法
2018/04/01 Python
Python OpenCV 使用滑动条来调整函数参数的方法
2019/07/08 Python
python错误调试及单元文档测试过程解析
2019/12/19 Python
Python内置数据类型list各方法的性能测试过程解析
2020/01/07 Python
python GUI库图形界面开发之PyQt5树形结构控件QTreeWidget详细使用方法与实例
2020/03/02 Python
Python Map 函数的使用
2020/08/28 Python
python字典与json转换的方法总结
2020/12/28 Python
HTML5 表单验证失败的提示语问题
2017/07/13 HTML / CSS
自我鉴定模板
2013/10/29 职场文书
语文教育专业推荐信范文
2013/11/25 职场文书
工程专业应届生求职信
2014/02/19 职场文书
初中班主任评语大全
2014/04/24 职场文书
2014大学生学生会工作总结
2014/12/19 职场文书
护林员个人总结
2015/03/04 职场文书
团日活动总结格式
2015/05/11 职场文书
千与千寻观后感
2015/06/04 职场文书
送给教师们,到底该如何写好教学反思?
2019/07/02 职场文书