基于TensorFlow的CNN实现Mnist手写数字识别


Posted in Python onJune 17, 2020

本文实例为大家分享了基于TensorFlow的CNN实现Mnist手写数字识别的具体代码,供大家参考,具体内容如下

一、CNN模型结构

基于TensorFlow的CNN实现Mnist手写数字识别

  • 输入层:Mnist数据集(28*28)
  • 第一层卷积:感受视野5*5,步长为1,卷积核:32个
  • 第一层池化:池化视野2*2,步长为2
  • 第二层卷积:感受视野5*5,步长为1,卷积核:64个
  • 第二层池化:池化视野2*2,步长为2
  • 全连接层:设置1024个神经元
  • 输出层:0~9十个数字类别

二、代码实现

import tensorflow as tf
#Tensorflow提供了一个类来处理MNIST数据
from tensorflow.examples.tutorials.mnist import input_data
import time
 
#载入数据集
mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
#设置批次的大小
batch_size=100
#计算一共有多少个批次
n_batch=mnist.train.num_examples//batch_size
 
#定义初始化权值函数
def weight_variable(shape):
 initial=tf.truncated_normal(shape,stddev=0.1)
 return tf.Variable(initial)
#定义初始化偏置函数
def bias_variable(shape):
 initial=tf.constant(0.1,shape=shape)
 return tf.Variable(initial)
#卷积层
def conv2d(input,filter):
 return tf.nn.conv2d(input,filter,strides=[1,1,1,1],padding='SAME')
#池化层
def max_pool_2x2(value):
 return tf.nn.max_pool(value,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
 
 
#输入层
#定义两个placeholder
x=tf.placeholder(tf.float32,[None,784]) #28*28
y=tf.placeholder(tf.float32,[None,10])
#改变x的格式转为4维的向量[batch,in_hight,in_width,in_channels]
x_image=tf.reshape(x,[-1,28,28,1])
 
 
#卷积、激励、池化操作
#初始化第一个卷积层的权值和偏置
W_conv1=weight_variable([5,5,1,32]) #5*5的采样窗口,32个卷积核从1个平面抽取特征
b_conv1=bias_variable([32]) #每一个卷积核一个偏置值
#把x_image和权值向量进行卷积,再加上偏置值,然后应用于relu激活函数
h_conv1=tf.nn.relu(conv2d(x_image,W_conv1)+b_conv1)
h_pool1=max_pool_2x2(h_conv1) #进行max_pooling 池化层
 
#初始化第二个卷积层的权值和偏置
W_conv2=weight_variable([5,5,32,64]) #5*5的采样窗口,64个卷积核从32个平面抽取特征
b_conv2=bias_variable([64])
#把第一个池化层结果和权值向量进行卷积,再加上偏置值,然后应用于relu激活函数
h_conv2=tf.nn.relu(conv2d(h_pool1,W_conv2)+b_conv2)
h_pool2=max_pool_2x2(h_conv2) #池化层
 
#28*28的图片第一次卷积后还是28*28,第一次池化后变为14*14
#第二次卷积后为14*14,第二次池化后变为了7*7
#经过上面操作后得到64张7*7的平面
 
 
#全连接层
#初始化第一个全连接层的权值
W_fc1=weight_variable([7*7*64,1024])#经过池化层后有7*7*64个神经元,全连接层有1024个神经元
b_fc1 = bias_variable([1024])#1024个节点
#把池化层2的输出扁平化为1维
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])
#求第一个全连接层的输出
h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1)+b_fc1)
 
#keep_prob用来表示神经元的输出概率
keep_prob=tf.placeholder(tf.float32)
h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)
 
#初始化第二个全连接层
W_fc2=weight_variable([1024,10])
b_fc2=bias_variable([10])
 
#输出层
#计算输出
prediction=tf.nn.softmax(tf.matmul(h_fc1_drop,W_fc2)+b_fc2)
 
#交叉熵代价函数
cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
#使用AdamOptimizer进行优化
train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
#结果存放在一个布尔列表中(argmax函数返回一维张量中最大的值所在的位置)
correct_prediction=tf.equal(tf.argmax(prediction,1),tf.argmax(y,1))
#求准确率(tf.cast将布尔值转换为float型)
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
 
#创建会话
with tf.Session() as sess:
 start_time=time.clock()
 sess.run(tf.global_variables_initializer()) #初始化变量
 for epoch in range(21): #迭代21次(训练21次)
 for batch in range(n_batch):
 batch_xs,batch_ys=mnist.train.next_batch(batch_size)
 sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys,keep_prob:0.7}) #进行迭代训练
 #测试数据计算出准确率
 acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})
 print('Iter'+str(epoch)+',Testing Accuracy='+str(acc))
 end_time=time.clock()
 print('Running time:%s Second'%(end_time-start_time)) #输出运行时间

运行结果:

基于TensorFlow的CNN实现Mnist手写数字识别

三、TensorFlow主要函数说明

1、卷积层

tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)

(1)data_format:表示输入的格式,有两种分别为:“NHWC”和“NCHW”,默认为“NHWC”

(2)input:输入是一个4维格式的(图像)数据,数据的 shape 由 data_format 决定:当 data_format 为“NHWC”输入数据的shape表示为[batch, in_height, in_width, in_channels],分别表示训练时一个batch的图片数量、图片高度、 图片宽度、 图像通道数。当 data_format 为“NHWC”输入数据的shape表示为[batch, in_channels, in_height, in_width]

(3)filter:卷积核是一个4维格式的数据:shape表示为:[height,width,in_channels, out_channels],分别表示卷积核的高、宽、深度(与输入的in_channels应相同)、输出 feature map的个数(即卷积核的个数)。

(4)strides:表示步长:一个长度为4的一维列表,每个元素跟data_format互相对应,表示在data_format每一维上的移动步长。当输入的默认格式为:“NHWC”,则 strides = [batch , in_height , in_width, in_channels]。其中 batch 和 in_channels 要求一定为1,即只能在一个样本的一个通道上的特征图上进行移动,in_height , in_width表示卷积核在特征图的高度和宽度上移动的布长。

(5)padding:表示填充方式:“SAME”表示采用填充的方式,简单地理解为以0填充边缘,当stride为1时,输入和输出的维度相同;“VALID”表示采用不填充的方式,多余地进行丢弃。

对于卷积操作:

基于TensorFlow的CNN实现Mnist手写数字识别

2、池化层

#池化层:
#Max pooling:取“池化视野”矩阵中的最大值
tf.nn.max_pool( value, ksize,strides,padding,data_format='NHWC',name=None)
#Average pooling:取“池化视野”矩阵中的平均值
tf.nn.avg_pool(value, ksize,strides,padding,data_format='NHWC',name=None)

参数说明:

(1)value:表示池化的输入:一个4维格式的数据,数据的 shape 由 data_format 决定,默认情况下shape 为[batch, height, width, channels]

(2)ksize:表示池化窗口的大小:一个长度为4的一维列表,一般为[1, height, width, 1],因不想在batch和channels上做池化,则将其值设为1。

(3)其他参数与 tf.nn.cov2d 类型

对于池化操作:

基于TensorFlow的CNN实现Mnist手写数字识别

基于TensorFlow的CNN实现Mnist手写数字识别

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现网站文件的全备份和差异备份
Nov 30 Python
python logging 日志轮转文件不删除问题的解决方法
Aug 02 Python
python字典键值对的添加和遍历方法
Sep 11 Python
对pandas里的loc并列条件索引的实例讲解
Nov 15 Python
在python中将字符串转为json对象并取值的方法
Dec 31 Python
查看python安装路径及pip安装的包列表及路径
Apr 03 Python
python装饰器常见使用方法分析
Jun 26 Python
python3读取图片并灰度化图片的四种方法(OpenCV、PIL.Image、TensorFlow方法)总结
Jul 04 Python
解决Django Static内容不能加载显示的问题
Jul 28 Python
python获取Linux发行版名称
Aug 30 Python
python实现文法左递归的消除方法
May 22 Python
Python趣味实例,实现一个简单的抽奖刮刮卡
Jul 18 Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
You might like
PHP与MySQL开发中页面出现乱码的一种解决方法
2007/07/29 PHP
PHP Undefined index报错的修复方法
2011/07/17 PHP
javascript编程起步(第六课)
2007/01/10 Javascript
javascript十个最常用的自定义函数(中文版)
2009/09/07 Javascript
jquery动态添加元素事件失效问题解决方法
2014/05/23 Javascript
JavaScript弹出窗口方法汇总
2014/08/12 Javascript
javascript常用代码段搜集
2014/12/04 Javascript
jQuery中data()方法用法实例
2014/12/27 Javascript
移动设备web开发首选框架:zeptojs介绍
2015/01/29 Javascript
JS获取下拉框显示值和判断单选按钮的方法
2015/07/09 Javascript
Jquery+Ajax+PHP+MySQL实现分类列表管理(上)
2015/10/28 Javascript
javascript实现网页端解压并查看zip文件
2015/12/15 Javascript
JS正则截取两个字符串之间及字符串前后内容的方法
2017/01/06 Javascript
Bootstrap表单控件使用方法详解
2017/01/11 Javascript
微信小程序wx.previewImage预览图片实例详解
2017/12/07 Javascript
10个经典的网页鼠标特效代码
2018/01/09 Javascript
vue基础知识--axios合并请求和slot
2020/06/04 Javascript
基于Echarts图表在div动态切换时不显示的解决方式
2020/07/20 Javascript
微信小程序input抖动问题的修复方法
2021/03/03 Javascript
python传递参数方式小结
2015/04/17 Python
python的多重继承的理解
2017/08/06 Python
python数字图像处理实现直方图与均衡化
2018/05/04 Python
解决python写入mysql中datetime类型遇到的问题
2018/06/21 Python
手把手教你如何安装Pycharm(详细图文教程)
2018/11/28 Python
使用tqdm显示Python代码执行进度功能
2019/12/08 Python
使用Keras构造简单的CNN网络实例
2020/06/29 Python
CSS3 制作旋转的大风车(充满童年回忆)
2013/01/30 HTML / CSS
英国手工制作的现代与经典的沙发和床:Love Your Home
2020/09/26 全球购物
大学生的网络创业计划书
2013/12/26 职场文书
学校师德师风整改措施
2014/10/27 职场文书
2014年工程工作总结
2014/11/25 职场文书
庆祝教师节主题班会
2015/08/17 职场文书
优质服务标语口号
2015/12/26 职场文书
小学五年级(说明文3篇)
2019/08/13 职场文书
php远程请求CURL案例(爬虫、保存登录状态)
2021/04/01 PHP
【海涛七七解说】DCG第二周:DK VS 天禄
2022/04/01 DOTA