tensorflow使用CNN分析mnist手写体数字数据集


Posted in Python onJune 17, 2020

本文实例为大家分享了tensorflow使用CNN分析mnist手写体数字数据集,供大家参考,具体内容如下

import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.examples.tutorials.mnist import input_data
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#把上述trX和teX的形状变为[-1,28,28,1],-1表示不考虑输入图片的数量,28×28是图片的长和宽的像素数,
# 1是通道(channel)数量,因为MNIST的图片是黑白的,所以通道是1,如果是RGB彩色图像,通道是3。
trX = trX.reshape(-1, 28, 28, 1) # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1) # 28x28x1 input img
 
X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])
#初始化权重与定义网络结构。
# 这里,我们将要构建一个拥有3个卷积层和3个池化层,随后接1个全连接层和1个输出层的卷积神经网络
def init_weights(shape):
 return tf.Variable(tf.random_normal(shape, stddev=0.01))
 
w = init_weights([3, 3, 1, 32])   # patch大小为3×3,输入维度为1,输出维度为32
w2 = init_weights([3, 3, 32, 64])   # patch大小为3×3,输入维度为32,输出维度为64
w3 = init_weights([3, 3, 64, 128])   # patch大小为3×3,输入维度为64,输出维度为128
w4 = init_weights([128 * 4 * 4, 625])  # 全连接层,输入维度为 128 × 4 × 4,是上一层的输出数据又三维的转变成一维, 输出维度为625
w_o = init_weights([625, 10]) # 输出层,输入维度为 625, 输出维度为10,代表10类(labels)
# 神经网络模型的构建函数,传入以下参数
# X:输入数据
# w:每一层的权重
# p_keep_conv,p_keep_hidden:dropout要保留的神经元比例
 
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
 # 第一组卷积层及池化层,最后dropout一些神经元
 l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
 # l1a shape=(?, 28, 28, 32)
 l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l1 shape=(?, 14, 14, 32)
 l1 = tf.nn.dropout(l1, p_keep_conv)
 
 # 第二组卷积层及池化层,最后dropout一些神经元
 l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
 # l2a shape=(?, 14, 14, 64)
 l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l2 shape=(?, 7, 7, 64)
 l2 = tf.nn.dropout(l2, p_keep_conv)
 # 第三组卷积层及池化层,最后dropout一些神经元
 l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
 # l3a shape=(?, 7, 7, 128)
 l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l3 shape=(?, 4, 4, 128)
 l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]]) # reshape to (?, 2048)
 l3 = tf.nn.dropout(l3, p_keep_conv)
 # 全连接层,最后dropout一些神经元
 l4 = tf.nn.relu(tf.matmul(l3, w4))
 l4 = tf.nn.dropout(l4, p_keep_hidden)
 # 输出层
 pyx = tf.matmul(l4, w_o)
 return pyx #返回预测值
 
#我们定义dropout的占位符——keep_conv,它表示在一层中有多少比例的神经元被保留下来。生成网络模型,得到预测值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) #得到预测值
#定义损失函数,这里我们仍然采用tf.nn.softmax_cross_entropy_with_logits来比较预测值和真实值的差异,并做均值处理;
# 定义训练的操作(train_op),采用实现RMSProp算法的优化器tf.train.RMSPropOptimizer,学习率为0.001,衰减值为0.9,使损失最小;
# 定义预测的操作(predict_op)
cost = tf.reduce_mean(tf.nn. softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)
#定义训练时的批次大小和评估时的批次大小
batch_size = 128
test_size = 256
#在一个会话中启动图,开始训练和评估
# Launch the graph in a session
with tf.Session() as sess:
 # you need to initialize all variables
 tf. global_variables_initializer().run()
 for i in range(100):
  training_batch = zip(range(0, len(trX), batch_size),
        range(batch_size, len(trX)+1, batch_size))
  for start, end in training_batch:
   sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
           p_keep_conv: 0.8, p_keep_hidden: 0.5})
 
  test_indices = np.arange(len(teX)) # Get A Test Batch
  np.random.shuffle(test_indices)
  test_indices = test_indices[0:test_size]
 
  print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
       sess.run(predict_op, feed_dict={X: teX[test_indices],
               p_keep_conv: 1.0,
               p_keep_hidden: 1.0})))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python编程中time模块的一些关键用法解析
Jan 19 Python
Python快速从注释生成文档的方法
Dec 26 Python
Python2.7+pytesser实现简单验证码的识别方法
Dec 29 Python
numpy添加新的维度:newaxis的方法
Aug 02 Python
Python中的pathlib.Path为什么不继承str详解
Jun 23 Python
pandas基于时间序列的固定时间间隔求均值的方法
Jul 04 Python
Pandas之排序函数sort_values()的实现
Jul 09 Python
python GUI库图形界面开发之PyQt5开发环境配置与基础使用
Feb 25 Python
python opencv pytesseract 验证码识别的实现
Aug 28 Python
Pycharm添加虚拟解释器报错问题解决方案
Oct 13 Python
python 爬虫如何正确的使用cookie
Oct 27 Python
python中time包实例详解
Feb 02 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
You might like
《OVERLORD》第四季,终于等到你!
2020/03/02 日漫
PHP读取目录下所有文件的代码
2008/01/07 PHP
php设计模式 Visitor 访问者模式
2011/06/28 PHP
php警告Creating default object from empty value 问题的解决方法
2014/04/02 PHP
使用PHP函数scandir排除特定目录
2014/06/12 PHP
PHP网页游戏学习之Xnova(ogame)源码解读(十二)
2014/06/25 PHP
php的4种常见运行方式
2015/03/20 PHP
PHP如何搭建百度Ueditor富文本编辑器
2018/09/21 PHP
js小技巧--自动隐藏红叉叉
2007/08/13 Javascript
js 页面刷新location.reload和location.replace的区别小结
2009/12/24 Javascript
javascript改变position值实现菜单滚动至顶部后固定
2013/01/18 Javascript
上传的js验证(图片/文件的扩展名)
2013/04/25 Javascript
jquery插件冲突(jquery.noconflict)解决方法分享
2014/03/20 Javascript
javascript简单实现滑动菜单效果的方法
2015/07/27 Javascript
Angular Js文件上传之form-data
2015/08/28 Javascript
js判断复选框是否选中及选中个数的实现代码
2016/05/30 Javascript
jQuery EasyUI Tab 选项卡问题小结
2016/08/16 Javascript
canvas实现流星雨的背景效果
2017/01/13 Javascript
JavaScript获取当前时间向前推三个月的方法示例
2017/02/04 Javascript
Vue.2.0.5过渡效果使用技巧
2017/03/16 Javascript
详解微信小程序Radio选中样式切换
2017/07/06 Javascript
js实现图片上传并预览功能
2018/08/06 Javascript
vue使用自定义指令实现拖拽
2021/01/29 Javascript
electron 安装,调试,打包的具体使用
2019/11/06 Javascript
python encode和decode的妙用
2009/09/02 Python
python使用mysqldb连接数据库操作方法示例详解
2013/12/03 Python
Python实现在matplotlib中两个坐标轴之间画一条直线光标的方法
2015/05/20 Python
python实现Flappy Bird源码
2018/12/24 Python
python如何从键盘获取输入实例
2020/06/18 Python
python 使用递归的方式实现语义图片分割功能
2020/07/16 Python
详解Python中openpyxl模块基本用法
2021/02/23 Python
html5响应式开发自动计算fontSize的方法
2020/01/13 HTML / CSS
幼儿园教研活动总结
2014/04/30 职场文书
2014年入党积极分子学习三中全会思想汇报
2014/09/13 职场文书
最感人的道歉情书
2015/05/12 职场文书
nginx 配置缓存
2022/05/11 Servers