基于TensorFlow的CNN实现Mnist手写数字识别


Posted in Python onJune 17, 2020

本文实例为大家分享了基于TensorFlow的CNN实现Mnist手写数字识别的具体代码,供大家参考,具体内容如下

一、CNN模型结构

基于TensorFlow的CNN实现Mnist手写数字识别

  • 输入层:Mnist数据集(28*28)
  • 第一层卷积:感受视野5*5,步长为1,卷积核:32个
  • 第一层池化:池化视野2*2,步长为2
  • 第二层卷积:感受视野5*5,步长为1,卷积核:64个
  • 第二层池化:池化视野2*2,步长为2
  • 全连接层:设置1024个神经元
  • 输出层:0~9十个数字类别

二、代码实现

import tensorflow as tf
#Tensorflow提供了一个类来处理MNIST数据
from tensorflow.examples.tutorials.mnist import input_data
import time
 
#载入数据集
mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
#设置批次的大小
batch_size=100
#计算一共有多少个批次
n_batch=mnist.train.num_examples//batch_size
 
#定义初始化权值函数
def weight_variable(shape):
 initial=tf.truncated_normal(shape,stddev=0.1)
 return tf.Variable(initial)
#定义初始化偏置函数
def bias_variable(shape):
 initial=tf.constant(0.1,shape=shape)
 return tf.Variable(initial)
#卷积层
def conv2d(input,filter):
 return tf.nn.conv2d(input,filter,strides=[1,1,1,1],padding='SAME')
#池化层
def max_pool_2x2(value):
 return tf.nn.max_pool(value,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
 
 
#输入层
#定义两个placeholder
x=tf.placeholder(tf.float32,[None,784]) #28*28
y=tf.placeholder(tf.float32,[None,10])
#改变x的格式转为4维的向量[batch,in_hight,in_width,in_channels]
x_image=tf.reshape(x,[-1,28,28,1])
 
 
#卷积、激励、池化操作
#初始化第一个卷积层的权值和偏置
W_conv1=weight_variable([5,5,1,32]) #5*5的采样窗口,32个卷积核从1个平面抽取特征
b_conv1=bias_variable([32]) #每一个卷积核一个偏置值
#把x_image和权值向量进行卷积,再加上偏置值,然后应用于relu激活函数
h_conv1=tf.nn.relu(conv2d(x_image,W_conv1)+b_conv1)
h_pool1=max_pool_2x2(h_conv1) #进行max_pooling 池化层
 
#初始化第二个卷积层的权值和偏置
W_conv2=weight_variable([5,5,32,64]) #5*5的采样窗口,64个卷积核从32个平面抽取特征
b_conv2=bias_variable([64])
#把第一个池化层结果和权值向量进行卷积,再加上偏置值,然后应用于relu激活函数
h_conv2=tf.nn.relu(conv2d(h_pool1,W_conv2)+b_conv2)
h_pool2=max_pool_2x2(h_conv2) #池化层
 
#28*28的图片第一次卷积后还是28*28,第一次池化后变为14*14
#第二次卷积后为14*14,第二次池化后变为了7*7
#经过上面操作后得到64张7*7的平面
 
 
#全连接层
#初始化第一个全连接层的权值
W_fc1=weight_variable([7*7*64,1024])#经过池化层后有7*7*64个神经元,全连接层有1024个神经元
b_fc1 = bias_variable([1024])#1024个节点
#把池化层2的输出扁平化为1维
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])
#求第一个全连接层的输出
h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1)+b_fc1)
 
#keep_prob用来表示神经元的输出概率
keep_prob=tf.placeholder(tf.float32)
h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)
 
#初始化第二个全连接层
W_fc2=weight_variable([1024,10])
b_fc2=bias_variable([10])
 
#输出层
#计算输出
prediction=tf.nn.softmax(tf.matmul(h_fc1_drop,W_fc2)+b_fc2)
 
#交叉熵代价函数
cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
#使用AdamOptimizer进行优化
train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
#结果存放在一个布尔列表中(argmax函数返回一维张量中最大的值所在的位置)
correct_prediction=tf.equal(tf.argmax(prediction,1),tf.argmax(y,1))
#求准确率(tf.cast将布尔值转换为float型)
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
 
#创建会话
with tf.Session() as sess:
 start_time=time.clock()
 sess.run(tf.global_variables_initializer()) #初始化变量
 for epoch in range(21): #迭代21次(训练21次)
 for batch in range(n_batch):
 batch_xs,batch_ys=mnist.train.next_batch(batch_size)
 sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys,keep_prob:0.7}) #进行迭代训练
 #测试数据计算出准确率
 acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})
 print('Iter'+str(epoch)+',Testing Accuracy='+str(acc))
 end_time=time.clock()
 print('Running time:%s Second'%(end_time-start_time)) #输出运行时间

运行结果:

基于TensorFlow的CNN实现Mnist手写数字识别

三、TensorFlow主要函数说明

1、卷积层

tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)

(1)data_format:表示输入的格式,有两种分别为:“NHWC”和“NCHW”,默认为“NHWC”

(2)input:输入是一个4维格式的(图像)数据,数据的 shape 由 data_format 决定:当 data_format 为“NHWC”输入数据的shape表示为[batch, in_height, in_width, in_channels],分别表示训练时一个batch的图片数量、图片高度、 图片宽度、 图像通道数。当 data_format 为“NHWC”输入数据的shape表示为[batch, in_channels, in_height, in_width]

(3)filter:卷积核是一个4维格式的数据:shape表示为:[height,width,in_channels, out_channels],分别表示卷积核的高、宽、深度(与输入的in_channels应相同)、输出 feature map的个数(即卷积核的个数)。

(4)strides:表示步长:一个长度为4的一维列表,每个元素跟data_format互相对应,表示在data_format每一维上的移动步长。当输入的默认格式为:“NHWC”,则 strides = [batch , in_height , in_width, in_channels]。其中 batch 和 in_channels 要求一定为1,即只能在一个样本的一个通道上的特征图上进行移动,in_height , in_width表示卷积核在特征图的高度和宽度上移动的布长。

(5)padding:表示填充方式:“SAME”表示采用填充的方式,简单地理解为以0填充边缘,当stride为1时,输入和输出的维度相同;“VALID”表示采用不填充的方式,多余地进行丢弃。

对于卷积操作:

基于TensorFlow的CNN实现Mnist手写数字识别

2、池化层

#池化层:
#Max pooling:取“池化视野”矩阵中的最大值
tf.nn.max_pool( value, ksize,strides,padding,data_format='NHWC',name=None)
#Average pooling:取“池化视野”矩阵中的平均值
tf.nn.avg_pool(value, ksize,strides,padding,data_format='NHWC',name=None)

参数说明:

(1)value:表示池化的输入:一个4维格式的数据,数据的 shape 由 data_format 决定,默认情况下shape 为[batch, height, width, channels]

(2)ksize:表示池化窗口的大小:一个长度为4的一维列表,一般为[1, height, width, 1],因不想在batch和channels上做池化,则将其值设为1。

(3)其他参数与 tf.nn.cov2d 类型

对于池化操作:

基于TensorFlow的CNN实现Mnist手写数字识别

基于TensorFlow的CNN实现Mnist手写数字识别

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简化Python的Django框架代码的一些示例
Apr 20 Python
python 把数据 json格式输出的实例代码
Oct 31 Python
使用python生成目录树
Mar 29 Python
python调用opencv实现猫脸检测功能
Jan 15 Python
Python3.5面向对象程序设计之类的继承和多态详解
Apr 24 Python
更新pip3与pyttsx3文字语音转换的实现方法
Aug 08 Python
将python包发布到PyPI和制作whl文件方式
Dec 25 Python
从多个tfrecord文件中无限读取文件的例子
Feb 17 Python
对python中arange()和linspace()的区别说明
May 03 Python
Python实现SMTP邮件发送
Jun 16 Python
基于Keras中Conv1D和Conv2D的区别说明
Jun 19 Python
PyTorch预训练Bert模型的示例
Nov 17 Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
You might like
ubuntu10.04配置 nginx+php-fpm模式的详解
2013/06/03 PHP
详解使用php调用微信接口上传永久素材
2017/04/11 PHP
multiSteps 基于Jquery的多步骤滑动切换插件
2011/07/22 Javascript
jQuery实现div浮动层跟随页面滚动效果
2014/02/11 Javascript
jquery实现带缩略图的可定制高度画廊效果(5种)
2015/08/28 Javascript
微信小程序 两种滑动方式(横向滑动,竖向滑动)详细及实例代码
2017/01/13 Javascript
JavaScript运行原理分析
2018/02/09 Javascript
express 项目分层实践详解
2018/12/10 Javascript
ES6知识点整理之函数数组参数的默认值及其解构应用示例
2019/04/17 Javascript
Vue实现兄弟组件间的联动效果
2020/01/21 Javascript
ES6 proxy和reflect的使用方法与应用实例分析
2020/02/15 Javascript
[46:20]CHAOS vs Alliacne 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/16 DOTA
python 读写txt文件 json文件的实现方法
2016/10/22 Python
python如何修改装饰器中参数
2018/03/20 Python
Python使用type关键字创建类步骤详解
2019/07/23 Python
python应用文件读取与登录注册功能
2019/09/23 Python
Python树莓派学习笔记之UDP传输视频帧操作详解
2019/11/15 Python
利用python实现.dcm格式图像转为.jpg格式
2020/01/13 Python
基于Python快速处理PDF表格数据
2020/06/03 Python
OpenCV+python实现实时目标检测功能
2020/06/24 Python
Python3爬虫中pyspider的安装步骤
2020/07/29 Python
HTML5标签使用方法详解
2015/11/27 HTML / CSS
埃弗顿足球俱乐部官方网上商店:Everton Direct
2018/01/13 全球购物
西班牙土拨鼠床垫公司,感觉在云端:Marmota
2019/03/18 全球购物
Keds加拿大官网:购买帆布运动鞋和皮鞋
2019/09/26 全球购物
幼儿园安全检查制度
2014/01/30 职场文书
室内拓展活动方案
2014/02/13 职场文书
企业文明单位申报材料
2014/05/16 职场文书
运动会广播稿150字(9篇)
2014/09/20 职场文书
群众路线剖析材料(四风问题)
2014/10/08 职场文书
2015年计生工作总结范文
2015/04/24 职场文书
证劵公司反洗钱宣传活动总结
2015/05/08 职场文书
工作态度怎么写
2015/06/25 职场文书
Golang实现AES对称加密的过程详解
2021/05/20 Golang
教你怎么用python selenium实现自动化测试
2021/05/27 Python
你真的会用Mysql的explain吗
2022/03/31 MySQL