基于TensorFlow的CNN实现Mnist手写数字识别


Posted in Python onJune 17, 2020

本文实例为大家分享了基于TensorFlow的CNN实现Mnist手写数字识别的具体代码,供大家参考,具体内容如下

一、CNN模型结构

基于TensorFlow的CNN实现Mnist手写数字识别

  • 输入层:Mnist数据集(28*28)
  • 第一层卷积:感受视野5*5,步长为1,卷积核:32个
  • 第一层池化:池化视野2*2,步长为2
  • 第二层卷积:感受视野5*5,步长为1,卷积核:64个
  • 第二层池化:池化视野2*2,步长为2
  • 全连接层:设置1024个神经元
  • 输出层:0~9十个数字类别

二、代码实现

import tensorflow as tf
#Tensorflow提供了一个类来处理MNIST数据
from tensorflow.examples.tutorials.mnist import input_data
import time
 
#载入数据集
mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
#设置批次的大小
batch_size=100
#计算一共有多少个批次
n_batch=mnist.train.num_examples//batch_size
 
#定义初始化权值函数
def weight_variable(shape):
 initial=tf.truncated_normal(shape,stddev=0.1)
 return tf.Variable(initial)
#定义初始化偏置函数
def bias_variable(shape):
 initial=tf.constant(0.1,shape=shape)
 return tf.Variable(initial)
#卷积层
def conv2d(input,filter):
 return tf.nn.conv2d(input,filter,strides=[1,1,1,1],padding='SAME')
#池化层
def max_pool_2x2(value):
 return tf.nn.max_pool(value,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
 
 
#输入层
#定义两个placeholder
x=tf.placeholder(tf.float32,[None,784]) #28*28
y=tf.placeholder(tf.float32,[None,10])
#改变x的格式转为4维的向量[batch,in_hight,in_width,in_channels]
x_image=tf.reshape(x,[-1,28,28,1])
 
 
#卷积、激励、池化操作
#初始化第一个卷积层的权值和偏置
W_conv1=weight_variable([5,5,1,32]) #5*5的采样窗口,32个卷积核从1个平面抽取特征
b_conv1=bias_variable([32]) #每一个卷积核一个偏置值
#把x_image和权值向量进行卷积,再加上偏置值,然后应用于relu激活函数
h_conv1=tf.nn.relu(conv2d(x_image,W_conv1)+b_conv1)
h_pool1=max_pool_2x2(h_conv1) #进行max_pooling 池化层
 
#初始化第二个卷积层的权值和偏置
W_conv2=weight_variable([5,5,32,64]) #5*5的采样窗口,64个卷积核从32个平面抽取特征
b_conv2=bias_variable([64])
#把第一个池化层结果和权值向量进行卷积,再加上偏置值,然后应用于relu激活函数
h_conv2=tf.nn.relu(conv2d(h_pool1,W_conv2)+b_conv2)
h_pool2=max_pool_2x2(h_conv2) #池化层
 
#28*28的图片第一次卷积后还是28*28,第一次池化后变为14*14
#第二次卷积后为14*14,第二次池化后变为了7*7
#经过上面操作后得到64张7*7的平面
 
 
#全连接层
#初始化第一个全连接层的权值
W_fc1=weight_variable([7*7*64,1024])#经过池化层后有7*7*64个神经元,全连接层有1024个神经元
b_fc1 = bias_variable([1024])#1024个节点
#把池化层2的输出扁平化为1维
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])
#求第一个全连接层的输出
h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1)+b_fc1)
 
#keep_prob用来表示神经元的输出概率
keep_prob=tf.placeholder(tf.float32)
h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)
 
#初始化第二个全连接层
W_fc2=weight_variable([1024,10])
b_fc2=bias_variable([10])
 
#输出层
#计算输出
prediction=tf.nn.softmax(tf.matmul(h_fc1_drop,W_fc2)+b_fc2)
 
#交叉熵代价函数
cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
#使用AdamOptimizer进行优化
train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
#结果存放在一个布尔列表中(argmax函数返回一维张量中最大的值所在的位置)
correct_prediction=tf.equal(tf.argmax(prediction,1),tf.argmax(y,1))
#求准确率(tf.cast将布尔值转换为float型)
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
 
#创建会话
with tf.Session() as sess:
 start_time=time.clock()
 sess.run(tf.global_variables_initializer()) #初始化变量
 for epoch in range(21): #迭代21次(训练21次)
 for batch in range(n_batch):
 batch_xs,batch_ys=mnist.train.next_batch(batch_size)
 sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys,keep_prob:0.7}) #进行迭代训练
 #测试数据计算出准确率
 acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})
 print('Iter'+str(epoch)+',Testing Accuracy='+str(acc))
 end_time=time.clock()
 print('Running time:%s Second'%(end_time-start_time)) #输出运行时间

运行结果:

基于TensorFlow的CNN实现Mnist手写数字识别

三、TensorFlow主要函数说明

1、卷积层

tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)

(1)data_format:表示输入的格式,有两种分别为:“NHWC”和“NCHW”,默认为“NHWC”

(2)input:输入是一个4维格式的(图像)数据,数据的 shape 由 data_format 决定:当 data_format 为“NHWC”输入数据的shape表示为[batch, in_height, in_width, in_channels],分别表示训练时一个batch的图片数量、图片高度、 图片宽度、 图像通道数。当 data_format 为“NHWC”输入数据的shape表示为[batch, in_channels, in_height, in_width]

(3)filter:卷积核是一个4维格式的数据:shape表示为:[height,width,in_channels, out_channels],分别表示卷积核的高、宽、深度(与输入的in_channels应相同)、输出 feature map的个数(即卷积核的个数)。

(4)strides:表示步长:一个长度为4的一维列表,每个元素跟data_format互相对应,表示在data_format每一维上的移动步长。当输入的默认格式为:“NHWC”,则 strides = [batch , in_height , in_width, in_channels]。其中 batch 和 in_channels 要求一定为1,即只能在一个样本的一个通道上的特征图上进行移动,in_height , in_width表示卷积核在特征图的高度和宽度上移动的布长。

(5)padding:表示填充方式:“SAME”表示采用填充的方式,简单地理解为以0填充边缘,当stride为1时,输入和输出的维度相同;“VALID”表示采用不填充的方式,多余地进行丢弃。

对于卷积操作:

基于TensorFlow的CNN实现Mnist手写数字识别

2、池化层

#池化层:
#Max pooling:取“池化视野”矩阵中的最大值
tf.nn.max_pool( value, ksize,strides,padding,data_format='NHWC',name=None)
#Average pooling:取“池化视野”矩阵中的平均值
tf.nn.avg_pool(value, ksize,strides,padding,data_format='NHWC',name=None)

参数说明:

(1)value:表示池化的输入:一个4维格式的数据,数据的 shape 由 data_format 决定,默认情况下shape 为[batch, height, width, channels]

(2)ksize:表示池化窗口的大小:一个长度为4的一维列表,一般为[1, height, width, 1],因不想在batch和channels上做池化,则将其值设为1。

(3)其他参数与 tf.nn.cov2d 类型

对于池化操作:

基于TensorFlow的CNN实现Mnist手写数字识别

基于TensorFlow的CNN实现Mnist手写数字识别

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之关于循环的小伎俩
Oct 02 Python
python写日志封装类实例
Jun 28 Python
python中的编码知识整理汇总
Jan 26 Python
python django事务transaction源码分析详解
Mar 17 Python
Python 操作文件的基本方法总结
Aug 10 Python
Python3读取Excel数据存入MySQL的方法
May 04 Python
python基础学习之如何对元组各个元素进行命名详解
Jul 12 Python
Python格式化输出字符串方法小结【%与format】
Oct 29 Python
python实现的MySQL增删改查操作实例小结
Dec 19 Python
Python3爬楼梯算法示例
Mar 04 Python
python安装cx_Oracle和wxPython的方法
Sep 14 Python
python之基数排序的实现
Jul 26 Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
You might like
linux下安装php的memcached客户端
2014/08/03 PHP
PHP邮件群发机实现代码
2016/02/16 PHP
PHP登录验证码的实现与使用方法
2016/07/07 PHP
thinkphp中多表查询中防止数据重复的sql语句(必看)
2016/09/22 PHP
比较简单的异步加载JS文件的代码
2009/07/18 Javascript
Jquery中LigerUi的弹出编辑框(实现方法)
2013/07/09 Javascript
jquery实现两边飘浮可关闭的对联广告
2015/11/27 Javascript
jquery通过name属性取值的简单实现方法
2016/06/20 Javascript
深入解析桶排序算法及Node.js上JavaScript的代码实现
2016/07/06 Javascript
微信小程序 条件渲染详解
2016/10/09 Javascript
jQuery学习笔记——jqGrid的使用记录(实现分页、搜索功能)
2016/11/09 Javascript
JavaScript实现邮箱地址自动匹配功能代码
2016/11/28 Javascript
js判断一个字符串是以某个字符串开头的简单实例
2016/12/27 Javascript
Mac系统下Webstorm快捷键整理大全
2017/05/28 Javascript
Node.js使用orm2进行update操作时关联字段无法修改的解决方法
2017/06/13 Javascript
一个有意思的鼠标点击文字特效jquery代码
2017/09/23 jQuery
简易Vue评论框架的实现(父组件的实现)
2018/01/08 Javascript
React router动态加载组件之适配器模式的应用详解
2018/09/12 Javascript
基于js判断浏览器是否支持webGL
2020/04/18 Javascript
Vue 3.0 全家桶抢先体验
2020/04/28 Javascript
vue项目打包后提交到git上为什么没有dist这个文件的解决方法
2020/09/16 Javascript
Python if语句知识点用法总结
2018/06/10 Python
Python3.5装饰器典型案例分析
2019/04/30 Python
python中bytes和str类型的区别
2019/10/21 Python
Python类class参数self原理解析
2020/11/19 Python
关于HTML5你必须知道的28个新特性,新技巧以及新技术
2012/05/28 HTML / CSS
使用HTML5做的导航条详细步骤
2020/10/19 HTML / CSS
Bloomingdale’s阿联酋:选购奢华时尚、美容及更多
2020/09/22 全球购物
乡镇网格化管理实施方案
2014/03/23 职场文书
教师党员个人剖析材料
2014/09/29 职场文书
幼师求职自荐信
2015/03/26 职场文书
敬老院义诊活动总结
2015/05/07 职场文书
幼儿园教师读书笔记
2015/06/29 职场文书
《大禹治水》教学反思
2016/02/22 职场文书
React实现动效弹窗组件
2021/06/21 Javascript
element tree树形组件回显数据问题解决
2022/08/14 Javascript