tensorflow使用CNN分析mnist手写体数字数据集


Posted in Python onJune 17, 2020

本文实例为大家分享了tensorflow使用CNN分析mnist手写体数字数据集,供大家参考,具体内容如下

import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.examples.tutorials.mnist import input_data
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#把上述trX和teX的形状变为[-1,28,28,1],-1表示不考虑输入图片的数量,28×28是图片的长和宽的像素数,
# 1是通道(channel)数量,因为MNIST的图片是黑白的,所以通道是1,如果是RGB彩色图像,通道是3。
trX = trX.reshape(-1, 28, 28, 1) # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1) # 28x28x1 input img
 
X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])
#初始化权重与定义网络结构。
# 这里,我们将要构建一个拥有3个卷积层和3个池化层,随后接1个全连接层和1个输出层的卷积神经网络
def init_weights(shape):
 return tf.Variable(tf.random_normal(shape, stddev=0.01))
 
w = init_weights([3, 3, 1, 32])   # patch大小为3×3,输入维度为1,输出维度为32
w2 = init_weights([3, 3, 32, 64])   # patch大小为3×3,输入维度为32,输出维度为64
w3 = init_weights([3, 3, 64, 128])   # patch大小为3×3,输入维度为64,输出维度为128
w4 = init_weights([128 * 4 * 4, 625])  # 全连接层,输入维度为 128 × 4 × 4,是上一层的输出数据又三维的转变成一维, 输出维度为625
w_o = init_weights([625, 10]) # 输出层,输入维度为 625, 输出维度为10,代表10类(labels)
# 神经网络模型的构建函数,传入以下参数
# X:输入数据
# w:每一层的权重
# p_keep_conv,p_keep_hidden:dropout要保留的神经元比例
 
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
 # 第一组卷积层及池化层,最后dropout一些神经元
 l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
 # l1a shape=(?, 28, 28, 32)
 l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l1 shape=(?, 14, 14, 32)
 l1 = tf.nn.dropout(l1, p_keep_conv)
 
 # 第二组卷积层及池化层,最后dropout一些神经元
 l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
 # l2a shape=(?, 14, 14, 64)
 l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l2 shape=(?, 7, 7, 64)
 l2 = tf.nn.dropout(l2, p_keep_conv)
 # 第三组卷积层及池化层,最后dropout一些神经元
 l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
 # l3a shape=(?, 7, 7, 128)
 l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l3 shape=(?, 4, 4, 128)
 l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]]) # reshape to (?, 2048)
 l3 = tf.nn.dropout(l3, p_keep_conv)
 # 全连接层,最后dropout一些神经元
 l4 = tf.nn.relu(tf.matmul(l3, w4))
 l4 = tf.nn.dropout(l4, p_keep_hidden)
 # 输出层
 pyx = tf.matmul(l4, w_o)
 return pyx #返回预测值
 
#我们定义dropout的占位符——keep_conv,它表示在一层中有多少比例的神经元被保留下来。生成网络模型,得到预测值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) #得到预测值
#定义损失函数,这里我们仍然采用tf.nn.softmax_cross_entropy_with_logits来比较预测值和真实值的差异,并做均值处理;
# 定义训练的操作(train_op),采用实现RMSProp算法的优化器tf.train.RMSPropOptimizer,学习率为0.001,衰减值为0.9,使损失最小;
# 定义预测的操作(predict_op)
cost = tf.reduce_mean(tf.nn. softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)
#定义训练时的批次大小和评估时的批次大小
batch_size = 128
test_size = 256
#在一个会话中启动图,开始训练和评估
# Launch the graph in a session
with tf.Session() as sess:
 # you need to initialize all variables
 tf. global_variables_initializer().run()
 for i in range(100):
  training_batch = zip(range(0, len(trX), batch_size),
        range(batch_size, len(trX)+1, batch_size))
  for start, end in training_batch:
   sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
           p_keep_conv: 0.8, p_keep_hidden: 0.5})
 
  test_indices = np.arange(len(teX)) # Get A Test Batch
  np.random.shuffle(test_indices)
  test_indices = test_indices[0:test_size]
 
  print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
       sess.run(predict_op, feed_dict={X: teX[test_indices],
               p_keep_conv: 1.0,
               p_keep_hidden: 1.0})))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
编写Python的web框架中的Model的教程
Apr 29 Python
python八大排序算法速度实例对比
Dec 06 Python
浅谈Scrapy框架普通反爬虫机制的应对策略
Dec 28 Python
Python实现的购物车功能示例
Feb 11 Python
Python 实现一行输入多个值的方法
Apr 21 Python
python/sympy求解矩阵方程的方法
Nov 08 Python
pygame游戏之旅 游戏中添加显示文字
Nov 20 Python
Python unittest框架操作实例解析
Apr 13 Python
使用tensorflow实现VGG网络,训练mnist数据集方式
May 26 Python
如何在Python对Excel进行读取
Jun 04 Python
浅谈对python中if、elif、else的误解
Aug 20 Python
Python类型转换的魔术方法详解
Dec 23 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
You might like
PHP session垃圾回收机制实例分析
2019/06/28 PHP
Javascript面向对象编程(三) 非构造函数的继承
2011/08/28 Javascript
javascript学习笔记(十五) js间歇调用和超时调用
2012/06/20 Javascript
浅谈javascript的原型继承
2012/07/25 Javascript
用jquery统计子菜单的条数示例代码
2013/10/18 Javascript
js冒泡、捕获事件及阻止冒泡方法详细总结
2014/05/08 Javascript
JavaScript判断变量是对象还是数组的方法
2014/08/28 Javascript
node.js中的http.createServer方法使用说明
2014/12/14 Javascript
jQuery ui实现动感的圆角渐变网站导航菜单效果代码
2015/08/26 Javascript
有关suggest快速删除后仍然出现下拉列表的bug问题
2016/12/02 Javascript
前端主流框架vue学习笔记第一篇
2017/07/26 Javascript
微信小程序实战篇之购物车的实现代码示例
2017/11/30 Javascript
Vue2.0 实现单选互斥的方法
2018/04/13 Javascript
微信小程序动态添加view组件的实例代码
2019/05/23 Javascript
bootstrap+spring boot实现面包屑导航功能(前端代码)
2019/10/09 Javascript
JS函数基本定义与用法示例
2020/01/15 Javascript
vue.config.js中配置Vue的路径别名的方法
2020/02/11 Javascript
js点击事件的执行过程实例分析【冒泡与捕获】
2020/04/11 Javascript
用Python编写一个国际象棋AI程序
2014/11/28 Python
浅析Python中的序列化存储的方法
2015/04/28 Python
python中global用法实例分析
2015/04/30 Python
Python 基础教程之包和类的用法
2017/02/23 Python
Python实现定制自动化业务流量报表周报功能【XlsxWriter模块】
2019/03/11 Python
python线程的几种创建方式详解
2019/08/29 Python
Windows10下 python3.7 安装 facenet的教程
2019/09/10 Python
python如何快速拼接字符串
2020/10/28 Python
Html5移动端获奖无缝滚动动画实现示例
2018/06/25 HTML / CSS
Lacoste澳大利亚官网:服装、鞋类及配饰
2018/11/14 全球购物
电气工程及其自动化学生实习自我鉴定
2013/09/19 职场文书
退学证明范本3篇
2014/10/29 职场文书
音乐教师个人总结
2015/02/06 职场文书
初中运动会前导词
2015/07/20 职场文书
Python字典和列表性能之间的比较
2021/06/07 Python
Python移位密码、仿射变换解密实例代码
2021/06/27 Python
mysql脏页是什么
2021/07/26 MySQL
Python学习之异常中的finally使用详解
2022/03/16 Python