解析Tensorflow之MNIST的使用


Posted in Python onJune 30, 2020

要说2017年什么技术最火爆,无疑是google领衔的深度学习开源框架Tensorflow。本文简述一下深度学习的入门例子MNIST。

深度学习简单介绍

首先要简单区别几个概念:人工智能,机器学习,深度学习,神经网络。这几个词应该是出现的最为频繁的,但是他们有什么区别呢?

人工智能:人类通过直觉可以解决的问题,如:自然语言理解,图像识别,语音识别等,计算机很难解决,而人工智能就是要解决这类问题。

机器学习:如果一个任务可以在任务T上,随着经验E的增加,效果P也随之增加,那么就认为这个程序可以从经验中学习。

深度学习:其核心就是自动将简单的特征组合成更加复杂的特征,并用这些特征解决问题。

神经网络:最初是一个生物学的概念,一般是指大脑神经元,触点,细胞等组成的网络,用于产生意识,帮助生物思考和行动,后来人工智能受神经网络的启发,发展出了人工神经网络。

来一张图就比较清楚了,如下图:

解析Tensorflow之MNIST的使用

MNIST解析

MNIST是深度学习的经典入门demo,他是由6万张训练图片和1万张测试图片构成的,每张图片都是28*28大小(如下图),而且都是黑白色构成(这里的黑色是一个0-1的浮点数,黑色越深表示数值越靠近1),这些图片是采集的不同的人手写从0到9的数字。TensorFlow将这个数据集和相关操作封装到了库中,下面我们来一步步解读深度学习MNIST的过程。

解析Tensorflow之MNIST的使用

上图就是4张MNIST图片。这些图片并不是传统意义上.jpg或者jpg格式的图片,因.jpg或者jpg的图片格式,会带有很多干扰信息(如:数据块,图片头,图片尾,长度等等),这些图片会被处理成很简易的二维数组,如图:

解析Tensorflow之MNIST的使用

可以看到,矩阵中有值的地方构成的图形,跟左边的图形很相似。之所以这样做,是为了让模型更简单清晰。特征更明显。

我们先看模型的代码以及如何训练模型:

mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
# x是特征值
 x = tf.placeholder(tf.float32, [None, 784])
# w表示每一个特征值(像素点)会影响结果的权重
 W = tf.Variable(tf.zeros([784, 10]))
 b = tf.Variable(tf.zeros([10]))
 y = tf.matmul(x, W) + b
# 是图片实际对应的值
 y_ = tf.placeholder(tf.float32, [None, 10])<br>
 cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
 sess = tf.InteractiveSession()
 tf.global_variables_initializer().run()
 # mnist.train 训练数据
 for _ in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
 
 #取得y得最大概率对应的数组索引来和y_的数组索引对比,如果索引相同,则表示预测正确
 correct_prediction = tf.equal(tf.arg_max(y, 1), tf.arg_max(y_, 1))
 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
 
 print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                    y_: mnist.test.labels}))

首先第一行是获取MNIST的数据集,我们逐一解释一下:

x(图片的特征值):这里使用了一个28*28=784列的数据来表示一个图片的构成,也就是说,每一个点都是这个图片的一个特征,这个其实比较好理解,因为每一个点都会对图片的样子和表达的含义有影响,只是影响的大小不同而已。至于为什么要将28*28的矩阵摊平成为一个1行784列的一维数组,我猜测可能是因为这样做会更加简单直观。

W(特征值对应的权重):这个值很重要,因为我们深度学习的过程,就是发现特征,经过一系列训练,从而得出每一个特征对结果影响的权重,我们训练,就是为了得到这个最佳权重值。

b(偏置量):是为了去线性话(我不是太清楚为什么需要这个值)

y(预测的结果):单个样本被预测出来是哪个数字的概率,比如:有可能结果是[ 1.07476616 -4.54194021 2.98073649 -7.42985344 3.29253793 1.967506178.59438515 -6.65950203 1.68721473 -0.9658531 ],则分别表示是0,1,2,3,4,5,6,7,8,9的概率,然后会取一个最大值来作为本次预测的结果,对于这个数组来说,结果是6(8.59438515)

y_(真实结果):来自MNIST的训练集,每一个图片所对应的真实值,如果是6,则表示为:[0 0 0 0 0 1 0 0 0]

再下面两行代码是损失函数(交叉熵)和梯度下降算法,通过不断的调整权重和偏置量的值,来逐步减小根据计算的预测结果和提供的真实结果之间的差异,以达到训练模型的目的。

算法确定以后便可以开始训练模型了,如下:

for _ in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

mnist.train.next_batch(100)是从训练集里一次提取100张图片数据来训练,然后循环1000次,以达到训练的目的。

之后的两行代码都有注释,不再累述。我们看最后一行代码:

print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                    y_: mnist.test.labels}))

mnist.test.images和mnist.test.labels是测试集,用来测试。accuracy是预测准确率。

当代码运行起来以后,我们发现,准确率大概在92%左右浮动。这个时候我们可能想看看到底是什么样的图片让预测不准。则添加如下代码:

for i in range(0, len(mnist.test.images)):
 result = sess.run(correct_prediction, feed_dict={x: np.array([mnist.test.images[i]]), y_: np.array([mnist.test.labels[i]])})
 if not result:
  print('预测的值是:',sess.run(y, feed_dict={x: np.array([mnist.test.images[i]]), y_: np.array([mnist.test.labels[i]])}))
  print('实际的值是:',sess.run(y_,feed_dict={x: np.array([mnist.test.images[i]]), y_: np.array([mnist.test.labels[i]])}))
  one_pic_arr = np.reshape(mnist.test.images[i], (28, 28))
  pic_matrix = np.matrix(one_pic_arr, dtype="float")
  plt.imshow(pic_matrix)
  pylab.show()
  break
 
print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                   y_: mnist.test.labels}))

for循环内指明一旦result为false,就表示出现了预测值和实际值不符合的图片,然后我们把值和图片分别打印出来看看:

预测的值是: [[ 1.82234347 -4.87242508 2.63052988 -6.56350136 2.73666072 2.30682945 8.59051228 -7.20512581 1.45552373 -0.90134078]]

对应的是数字6。
实际的值是: [[ 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]]

对应的是数字5。

我们再来看看图片是什么样子的:

解析Tensorflow之MNIST的使用

的确像5又像6。

总体来说,只有92%的准确率,还是比较低的,后续会解析一下比较适合识别图片的卷积神经网络,准确率可以达到99%以上。

一些体会与感想

我本人是一名iOS开发,也是迎着人工智能的浪潮开始一路学习,我觉得人工智能终将改变我们的生活,也会成为未来的一个热门学科。这一个多月的自学下来,我觉得最为困难的是克服自己的畏难情绪,因为我完全没有AI方面的任何经验,而且工作年限太久,线性代数,概率论等知识早已还给老师,所以在开始的时候,总是反反复复不停犹豫,纠结到底要不要把时间花费在研究深度学习上面。但是后来一想,假如我不学AI的东西,若干年后,AI发展越发成熟,到时候想学也会难以跟上步伐,而且,让电脑学会思考这本身就是一件很让人兴奋的事情,既然想学,有什么理由不去学呢?与大家共勉。

参考文章:

https://zhuanlan.zhihu.com/p/25482889

https://hit-scir.gitbooks.io/neural-networks-and-deep-learning-zh_cn/content/chap1/c1s0.html

到此这篇关于解析Tensorflow之MNIST的使用的文章就介绍到这了,更多相关Tensorflow MNIST内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Linux上安装Python的PIL和Pillow库处理图片的实例教程
Jun 23 Python
Python爬取京东的商品分类与链接
Aug 26 Python
Python更新数据库脚本两种方法及对比介绍
Jul 27 Python
django rest framework之请求与响应(详解)
Nov 06 Python
python opencv实现任意角度的透视变换实例代码
Jan 12 Python
Python使用装饰器进行django开发实例代码
Feb 06 Python
详解python中asyncio模块
Mar 03 Python
django利用request id便于定位及给日志加上request_id
Aug 26 Python
opencv-python 读取图像并转换颜色空间实例
Dec 09 Python
python中count函数简单用法
Jan 05 Python
在jupyter notebook中调用.ipynb文件方式
Apr 14 Python
基于Python的图像阈值化分割(迭代法)
Nov 20 Python
Tensorflow tensor 数学运算和逻辑运算方式
Jun 30 #Python
Python requests模块安装及使用教程图解
Jun 30 #Python
在Tensorflow中实现leakyRelu操作详解(高效)
Jun 30 #Python
TensorFlow-gpu和opencv安装详细教程
Jun 30 #Python
tensorflow 2.1.0 安装与实战教程(CASIA FACE v5)
Jun 30 #Python
python 最简单的实现适配器设计模式的示例
Jun 30 #Python
Tensorflow--取tensorf指定列的操作方式
Jun 30 #Python
You might like
php数据结构 算法(PHP描述) 简单选择排序 simple selection sort
2011/08/09 PHP
AJAX的跨域访问-两种有效的解决方法介绍
2013/06/22 PHP
PHP中的traits实现代码复用使用实例
2015/05/13 PHP
php获取错误信息的方法
2015/07/17 PHP
PHP实现重载的常用方法实例详解
2017/10/18 PHP
document 和 document.all 分别什么时候用
2006/06/22 Javascript
JavaScript高级程序设计 阅读笔记(四) ECMAScript中的类型转换
2012/02/27 Javascript
JS字符串累加Array不一定比字符串累加快(根据电脑配置)
2012/05/14 Javascript
不要使用jQuery触发原生事件的方法
2014/03/03 Javascript
JS的encodeURI和java的URLDecoder.decode使用介绍
2014/05/08 Javascript
Javascript冒泡排序算法详解
2014/12/03 Javascript
简介alert()与console.log()的不同
2015/08/26 Javascript
无法获取隐藏元素宽度和高度的解决方案
2017/03/07 Javascript
原生js实现验证码功能
2017/03/16 Javascript
详解在Angularjs中ui-sref和$state.go如何传递参数
2017/04/24 Javascript
JS控制鼠标拒绝点击某一按钮的实例
2017/12/29 Javascript
理解Koa2中的async&amp;await的用法
2018/02/05 Javascript
关于RxJS Subject的学习笔记
2018/12/05 Javascript
浅析vue插槽和作用域插槽的理解
2019/04/22 Javascript
TypeScript类型声明书写详解
2019/08/28 Javascript
antd-mobile ListView长列表的数据更新遇到的坑
2020/04/08 Javascript
Nodejs实现微信分账的示例代码
2021/01/19 NodeJs
使用PYTHON接收多播数据的代码
2012/03/01 Python
使用Python来编写HTTP服务器的超级指南
2016/02/18 Python
python3 shelve模块的详解
2017/07/08 Python
用uWSGI和Nginx部署Flask项目的方法示例
2019/05/05 Python
python飞机大战pygame碰撞检测实现方法分析
2019/12/17 Python
HTML5的结构和语义(4):语义性的内联元素
2008/10/17 HTML / CSS
阿玛尼美妆加拿大官方商城:Giorgio Armani Beauty加拿大
2017/10/24 全球购物
Java里面如何创建一个内部类的实例
2015/01/19 面试题
教师四风对照检查材料思想汇报
2014/09/17 职场文书
2014年少先队工作总结
2014/12/03 职场文书
2014年创先争优工作总结
2014/12/11 职场文书
辞职信模板(中英文版)
2015/02/27 职场文书
PHP实现两种排课方式
2021/06/26 PHP
草系十大最强宝可梦,纸片人上榜,榜首大家最熟悉
2022/03/18 日漫