tensorflow使用CNN分析mnist手写体数字数据集


Posted in Python onJune 17, 2020

本文实例为大家分享了tensorflow使用CNN分析mnist手写体数字数据集,供大家参考,具体内容如下

import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.examples.tutorials.mnist import input_data
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#把上述trX和teX的形状变为[-1,28,28,1],-1表示不考虑输入图片的数量,28×28是图片的长和宽的像素数,
# 1是通道(channel)数量,因为MNIST的图片是黑白的,所以通道是1,如果是RGB彩色图像,通道是3。
trX = trX.reshape(-1, 28, 28, 1) # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1) # 28x28x1 input img
 
X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])
#初始化权重与定义网络结构。
# 这里,我们将要构建一个拥有3个卷积层和3个池化层,随后接1个全连接层和1个输出层的卷积神经网络
def init_weights(shape):
 return tf.Variable(tf.random_normal(shape, stddev=0.01))
 
w = init_weights([3, 3, 1, 32])   # patch大小为3×3,输入维度为1,输出维度为32
w2 = init_weights([3, 3, 32, 64])   # patch大小为3×3,输入维度为32,输出维度为64
w3 = init_weights([3, 3, 64, 128])   # patch大小为3×3,输入维度为64,输出维度为128
w4 = init_weights([128 * 4 * 4, 625])  # 全连接层,输入维度为 128 × 4 × 4,是上一层的输出数据又三维的转变成一维, 输出维度为625
w_o = init_weights([625, 10]) # 输出层,输入维度为 625, 输出维度为10,代表10类(labels)
# 神经网络模型的构建函数,传入以下参数
# X:输入数据
# w:每一层的权重
# p_keep_conv,p_keep_hidden:dropout要保留的神经元比例
 
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
 # 第一组卷积层及池化层,最后dropout一些神经元
 l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
 # l1a shape=(?, 28, 28, 32)
 l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l1 shape=(?, 14, 14, 32)
 l1 = tf.nn.dropout(l1, p_keep_conv)
 
 # 第二组卷积层及池化层,最后dropout一些神经元
 l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
 # l2a shape=(?, 14, 14, 64)
 l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l2 shape=(?, 7, 7, 64)
 l2 = tf.nn.dropout(l2, p_keep_conv)
 # 第三组卷积层及池化层,最后dropout一些神经元
 l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
 # l3a shape=(?, 7, 7, 128)
 l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l3 shape=(?, 4, 4, 128)
 l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]]) # reshape to (?, 2048)
 l3 = tf.nn.dropout(l3, p_keep_conv)
 # 全连接层,最后dropout一些神经元
 l4 = tf.nn.relu(tf.matmul(l3, w4))
 l4 = tf.nn.dropout(l4, p_keep_hidden)
 # 输出层
 pyx = tf.matmul(l4, w_o)
 return pyx #返回预测值
 
#我们定义dropout的占位符——keep_conv,它表示在一层中有多少比例的神经元被保留下来。生成网络模型,得到预测值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) #得到预测值
#定义损失函数,这里我们仍然采用tf.nn.softmax_cross_entropy_with_logits来比较预测值和真实值的差异,并做均值处理;
# 定义训练的操作(train_op),采用实现RMSProp算法的优化器tf.train.RMSPropOptimizer,学习率为0.001,衰减值为0.9,使损失最小;
# 定义预测的操作(predict_op)
cost = tf.reduce_mean(tf.nn. softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)
#定义训练时的批次大小和评估时的批次大小
batch_size = 128
test_size = 256
#在一个会话中启动图,开始训练和评估
# Launch the graph in a session
with tf.Session() as sess:
 # you need to initialize all variables
 tf. global_variables_initializer().run()
 for i in range(100):
  training_batch = zip(range(0, len(trX), batch_size),
        range(batch_size, len(trX)+1, batch_size))
  for start, end in training_batch:
   sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
           p_keep_conv: 0.8, p_keep_hidden: 0.5})
 
  test_indices = np.arange(len(teX)) # Get A Test Batch
  np.random.shuffle(test_indices)
  test_indices = test_indices[0:test_size]
 
  print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
       sess.run(predict_op, feed_dict={X: teX[test_indices],
               p_keep_conv: 1.0,
               p_keep_hidden: 1.0})))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用pygame模块编写俄罗斯方块游戏的代码实例
Dec 08 Python
Django自定义过滤器定义与用法示例
Mar 22 Python
python把转列表为集合的方法
Jun 28 Python
Python编程实现tail-n查看日志文件的方法
Jul 08 Python
pygame实现贪吃蛇游戏(上)
Oct 29 Python
详解centos7+django+python3+mysql+阿里云部署项目全流程
Nov 15 Python
python检查目录文件权限并修改目录文件权限的操作
Mar 11 Python
django实现模板中的字符串文字和自动转义
Mar 31 Python
Python 线性回归分析以及评价指标详解
Apr 02 Python
Python 给下载文件显示进度条和下载时间的实现
Apr 02 Python
Pytorch十九种损失函数的使用详解
Apr 29 Python
Django如何与Ajax交互
Apr 29 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
You might like
php 强制下载文件实现代码
2013/10/28 PHP
PHP源码分析之变量的存储过程分解
2014/07/03 PHP
通过Mootools 1.2来操纵HTML DOM元素
2009/09/15 Javascript
基于jquery的3d效果实现代码
2011/03/23 Javascript
jquery focus(fn),blur(fn)方法实例代码
2011/12/16 Javascript
javascript中方便增删改cookie的一个类
2012/10/11 Javascript
javascript alert乱码的解决方法
2013/11/05 Javascript
JS实现方向键切换输入框焦点的方法
2015/08/19 Javascript
深入理解JavaScript中为什么string可以拥有方法
2016/05/24 Javascript
JQuery PHP图片在线裁剪实例
2020/07/27 Javascript
JS实现拖拽的方法分析
2016/12/20 Javascript
JavaScript对JSON数据进行排序和搜索
2017/07/24 Javascript
vue 里面使用axios 和封装的示例代码
2017/09/01 Javascript
js判断传入时间和当前时间大小实例(超简单)
2018/01/11 Javascript
新手简单了解vue
2019/05/29 Javascript
jquery使用echarts实现有向图可视化功能示例
2019/11/25 jQuery
JavaScript中的惰性载入函数及优势
2020/02/18 Javascript
原生js实现购物车
2020/09/23 Javascript
Python爬虫之正则表达式的使用教程详解
2018/10/25 Python
Python构建图像分类识别器的方法
2019/01/12 Python
Python实现DDos攻击实例详解
2019/02/02 Python
Python3安装psycopy2以及遇到问题解决方法
2019/07/03 Python
python自动生成model文件过程详解
2019/11/02 Python
Python爬取网站图片并保存的实现示例
2021/02/26 Python
CSS3 伪类选择器 nth-child()说明
2010/07/10 HTML / CSS
Myprotein加拿大官网:欧洲第一的运动营养品牌
2018/01/06 全球购物
武汉高蓝德国际.net机试
2016/06/24 面试题
营业员个人总结的自我评价
2013/10/25 职场文书
20年同学聚会感言
2014/02/03 职场文书
税务会计岗位职责
2014/02/18 职场文书
自我鉴定标准格式
2014/03/19 职场文书
应届生求职信
2014/05/31 职场文书
电子商务专业毕业生求职信
2014/06/12 职场文书
python实现图片九宫格分割的示例
2021/04/25 Python
JS ES6异步解决方案
2021/04/29 Javascript
Python趣味挑战之教你用pygame画进度条
2021/05/31 Python