基于TensorFlow的CNN实现Mnist手写数字识别


Posted in Python onJune 17, 2020

本文实例为大家分享了基于TensorFlow的CNN实现Mnist手写数字识别的具体代码,供大家参考,具体内容如下

一、CNN模型结构

基于TensorFlow的CNN实现Mnist手写数字识别

  • 输入层:Mnist数据集(28*28)
  • 第一层卷积:感受视野5*5,步长为1,卷积核:32个
  • 第一层池化:池化视野2*2,步长为2
  • 第二层卷积:感受视野5*5,步长为1,卷积核:64个
  • 第二层池化:池化视野2*2,步长为2
  • 全连接层:设置1024个神经元
  • 输出层:0~9十个数字类别

二、代码实现

import tensorflow as tf
#Tensorflow提供了一个类来处理MNIST数据
from tensorflow.examples.tutorials.mnist import input_data
import time
 
#载入数据集
mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
#设置批次的大小
batch_size=100
#计算一共有多少个批次
n_batch=mnist.train.num_examples//batch_size
 
#定义初始化权值函数
def weight_variable(shape):
 initial=tf.truncated_normal(shape,stddev=0.1)
 return tf.Variable(initial)
#定义初始化偏置函数
def bias_variable(shape):
 initial=tf.constant(0.1,shape=shape)
 return tf.Variable(initial)
#卷积层
def conv2d(input,filter):
 return tf.nn.conv2d(input,filter,strides=[1,1,1,1],padding='SAME')
#池化层
def max_pool_2x2(value):
 return tf.nn.max_pool(value,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
 
 
#输入层
#定义两个placeholder
x=tf.placeholder(tf.float32,[None,784]) #28*28
y=tf.placeholder(tf.float32,[None,10])
#改变x的格式转为4维的向量[batch,in_hight,in_width,in_channels]
x_image=tf.reshape(x,[-1,28,28,1])
 
 
#卷积、激励、池化操作
#初始化第一个卷积层的权值和偏置
W_conv1=weight_variable([5,5,1,32]) #5*5的采样窗口,32个卷积核从1个平面抽取特征
b_conv1=bias_variable([32]) #每一个卷积核一个偏置值
#把x_image和权值向量进行卷积,再加上偏置值,然后应用于relu激活函数
h_conv1=tf.nn.relu(conv2d(x_image,W_conv1)+b_conv1)
h_pool1=max_pool_2x2(h_conv1) #进行max_pooling 池化层
 
#初始化第二个卷积层的权值和偏置
W_conv2=weight_variable([5,5,32,64]) #5*5的采样窗口,64个卷积核从32个平面抽取特征
b_conv2=bias_variable([64])
#把第一个池化层结果和权值向量进行卷积,再加上偏置值,然后应用于relu激活函数
h_conv2=tf.nn.relu(conv2d(h_pool1,W_conv2)+b_conv2)
h_pool2=max_pool_2x2(h_conv2) #池化层
 
#28*28的图片第一次卷积后还是28*28,第一次池化后变为14*14
#第二次卷积后为14*14,第二次池化后变为了7*7
#经过上面操作后得到64张7*7的平面
 
 
#全连接层
#初始化第一个全连接层的权值
W_fc1=weight_variable([7*7*64,1024])#经过池化层后有7*7*64个神经元,全连接层有1024个神经元
b_fc1 = bias_variable([1024])#1024个节点
#把池化层2的输出扁平化为1维
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])
#求第一个全连接层的输出
h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1)+b_fc1)
 
#keep_prob用来表示神经元的输出概率
keep_prob=tf.placeholder(tf.float32)
h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)
 
#初始化第二个全连接层
W_fc2=weight_variable([1024,10])
b_fc2=bias_variable([10])
 
#输出层
#计算输出
prediction=tf.nn.softmax(tf.matmul(h_fc1_drop,W_fc2)+b_fc2)
 
#交叉熵代价函数
cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
#使用AdamOptimizer进行优化
train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
#结果存放在一个布尔列表中(argmax函数返回一维张量中最大的值所在的位置)
correct_prediction=tf.equal(tf.argmax(prediction,1),tf.argmax(y,1))
#求准确率(tf.cast将布尔值转换为float型)
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
 
#创建会话
with tf.Session() as sess:
 start_time=time.clock()
 sess.run(tf.global_variables_initializer()) #初始化变量
 for epoch in range(21): #迭代21次(训练21次)
 for batch in range(n_batch):
 batch_xs,batch_ys=mnist.train.next_batch(batch_size)
 sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys,keep_prob:0.7}) #进行迭代训练
 #测试数据计算出准确率
 acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})
 print('Iter'+str(epoch)+',Testing Accuracy='+str(acc))
 end_time=time.clock()
 print('Running time:%s Second'%(end_time-start_time)) #输出运行时间

运行结果:

基于TensorFlow的CNN实现Mnist手写数字识别

三、TensorFlow主要函数说明

1、卷积层

tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)

(1)data_format:表示输入的格式,有两种分别为:“NHWC”和“NCHW”,默认为“NHWC”

(2)input:输入是一个4维格式的(图像)数据,数据的 shape 由 data_format 决定:当 data_format 为“NHWC”输入数据的shape表示为[batch, in_height, in_width, in_channels],分别表示训练时一个batch的图片数量、图片高度、 图片宽度、 图像通道数。当 data_format 为“NHWC”输入数据的shape表示为[batch, in_channels, in_height, in_width]

(3)filter:卷积核是一个4维格式的数据:shape表示为:[height,width,in_channels, out_channels],分别表示卷积核的高、宽、深度(与输入的in_channels应相同)、输出 feature map的个数(即卷积核的个数)。

(4)strides:表示步长:一个长度为4的一维列表,每个元素跟data_format互相对应,表示在data_format每一维上的移动步长。当输入的默认格式为:“NHWC”,则 strides = [batch , in_height , in_width, in_channels]。其中 batch 和 in_channels 要求一定为1,即只能在一个样本的一个通道上的特征图上进行移动,in_height , in_width表示卷积核在特征图的高度和宽度上移动的布长。

(5)padding:表示填充方式:“SAME”表示采用填充的方式,简单地理解为以0填充边缘,当stride为1时,输入和输出的维度相同;“VALID”表示采用不填充的方式,多余地进行丢弃。

对于卷积操作:

基于TensorFlow的CNN实现Mnist手写数字识别

2、池化层

#池化层:
#Max pooling:取“池化视野”矩阵中的最大值
tf.nn.max_pool( value, ksize,strides,padding,data_format='NHWC',name=None)
#Average pooling:取“池化视野”矩阵中的平均值
tf.nn.avg_pool(value, ksize,strides,padding,data_format='NHWC',name=None)

参数说明:

(1)value:表示池化的输入:一个4维格式的数据,数据的 shape 由 data_format 决定,默认情况下shape 为[batch, height, width, channels]

(2)ksize:表示池化窗口的大小:一个长度为4的一维列表,一般为[1, height, width, 1],因不想在batch和channels上做池化,则将其值设为1。

(3)其他参数与 tf.nn.cov2d 类型

对于池化操作:

基于TensorFlow的CNN实现Mnist手写数字识别

基于TensorFlow的CNN实现Mnist手写数字识别

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python聚类算法之基本K均值实例详解
Nov 20 Python
谈谈Python进行验证码识别的一些想法
Jan 25 Python
使用C#配合ArcGIS Engine进行地理信息系统开发
Feb 19 Python
VTK与Python实现机械臂三维模型可视化详解
Dec 13 Python
python3+PyQt5重新实现自定义数据拖放处理
Apr 19 Python
对Python3中列表乘以某一个数的示例详解
Jul 20 Python
解决在pycharm运行代码,调用CMD窗口的命令运行显示乱码问题
Aug 23 Python
淘宝秒杀python脚本 扫码登录版
Sep 19 Python
使用pytorch实现可视化中间层的结果
Dec 30 Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 Python
python中uuid模块实例浅析
Dec 29 Python
详解Python+OpenCV绘制灰度直方图
Mar 22 Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
You might like
各种战术和打法的原创者
2020/03/04 星际争霸
PHP读取目录下所有文件的代码
2008/01/07 PHP
laravel-admin解决表单select联动时,编辑默认没选上的问题
2019/09/30 PHP
jquery中输入验证中一个不错的效果
2010/08/21 Javascript
用JavaScript实现动画效果的方法
2013/07/20 Javascript
jquery cookie的用法总结
2013/11/18 Javascript
javascripit实现密码强度检测代码分享
2013/12/12 Javascript
javascript在网页中实现读取剪贴板粘贴截图功能
2014/06/07 Javascript
常用的JavaScript模板引擎介绍
2015/02/28 Javascript
浅谈AngularJS中ng-class的使用方法
2016/11/11 Javascript
基于JavaScript实现滑动门效果
2017/03/16 Javascript
nodeJS实现简单网页爬虫功能的实例(分享)
2017/06/08 NodeJs
vue实现表格数据的增删改查
2017/07/10 Javascript
JavaScript闭包的简单应用
2017/09/01 Javascript
vue+socket.io+express+mongodb 实现简易多房间在线群聊示例
2017/10/21 Javascript
vue 项目地址去掉 #的方法
2018/10/20 Javascript
Vue 动态组件与 v-once 指令的实现
2019/02/12 Javascript
Openlayers显示地理位置坐标的方法
2020/09/28 Javascript
JavaScript缓动动画函数的封装方法
2020/11/25 Javascript
[49:41]NB vs NAVI Supermajor小组赛A组 BO3 第一场 6.2
2018/06/03 DOTA
详解python中requirements.txt的一切
2017/03/03 Python
详解python上传文件和字符到PHP服务器
2017/11/24 Python
浅谈python 中类属性共享的问题
2019/07/02 Python
ubuntu上安装python的实例方法
2019/09/30 Python
Python中使用gflags实例及原理解析
2019/12/13 Python
python 插入日期数据到Oracle实例
2020/03/02 Python
python pandas.DataFrame.loc函数使用详解
2020/03/26 Python
python 制作网站小说下载器
2021/02/20 Python
网络维护中文求职信
2014/01/03 职场文书
爱我中华教学反思
2014/04/28 职场文书
副护士长竞聘演讲稿
2014/04/30 职场文书
2016年少先队活动总结
2016/04/06 职场文书
[有人@你]你有一封绿色倡议书,请查收!
2019/07/18 职场文书
html form表单基础入门案例讲解
2021/07/15 HTML / CSS
Win10 heic文件怎么打开 ? Win10 heic文件打开教程
2022/04/06 数码科技
Redis过期数据是否会被立马删除
2022/07/23 Redis