tensorflow使用CNN分析mnist手写体数字数据集


Posted in Python onJune 17, 2020

本文实例为大家分享了tensorflow使用CNN分析mnist手写体数字数据集,供大家参考,具体内容如下

import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.examples.tutorials.mnist import input_data
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#把上述trX和teX的形状变为[-1,28,28,1],-1表示不考虑输入图片的数量,28×28是图片的长和宽的像素数,
# 1是通道(channel)数量,因为MNIST的图片是黑白的,所以通道是1,如果是RGB彩色图像,通道是3。
trX = trX.reshape(-1, 28, 28, 1) # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1) # 28x28x1 input img
 
X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])
#初始化权重与定义网络结构。
# 这里,我们将要构建一个拥有3个卷积层和3个池化层,随后接1个全连接层和1个输出层的卷积神经网络
def init_weights(shape):
 return tf.Variable(tf.random_normal(shape, stddev=0.01))
 
w = init_weights([3, 3, 1, 32])   # patch大小为3×3,输入维度为1,输出维度为32
w2 = init_weights([3, 3, 32, 64])   # patch大小为3×3,输入维度为32,输出维度为64
w3 = init_weights([3, 3, 64, 128])   # patch大小为3×3,输入维度为64,输出维度为128
w4 = init_weights([128 * 4 * 4, 625])  # 全连接层,输入维度为 128 × 4 × 4,是上一层的输出数据又三维的转变成一维, 输出维度为625
w_o = init_weights([625, 10]) # 输出层,输入维度为 625, 输出维度为10,代表10类(labels)
# 神经网络模型的构建函数,传入以下参数
# X:输入数据
# w:每一层的权重
# p_keep_conv,p_keep_hidden:dropout要保留的神经元比例
 
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
 # 第一组卷积层及池化层,最后dropout一些神经元
 l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
 # l1a shape=(?, 28, 28, 32)
 l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l1 shape=(?, 14, 14, 32)
 l1 = tf.nn.dropout(l1, p_keep_conv)
 
 # 第二组卷积层及池化层,最后dropout一些神经元
 l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
 # l2a shape=(?, 14, 14, 64)
 l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l2 shape=(?, 7, 7, 64)
 l2 = tf.nn.dropout(l2, p_keep_conv)
 # 第三组卷积层及池化层,最后dropout一些神经元
 l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
 # l3a shape=(?, 7, 7, 128)
 l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l3 shape=(?, 4, 4, 128)
 l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]]) # reshape to (?, 2048)
 l3 = tf.nn.dropout(l3, p_keep_conv)
 # 全连接层,最后dropout一些神经元
 l4 = tf.nn.relu(tf.matmul(l3, w4))
 l4 = tf.nn.dropout(l4, p_keep_hidden)
 # 输出层
 pyx = tf.matmul(l4, w_o)
 return pyx #返回预测值
 
#我们定义dropout的占位符——keep_conv,它表示在一层中有多少比例的神经元被保留下来。生成网络模型,得到预测值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) #得到预测值
#定义损失函数,这里我们仍然采用tf.nn.softmax_cross_entropy_with_logits来比较预测值和真实值的差异,并做均值处理;
# 定义训练的操作(train_op),采用实现RMSProp算法的优化器tf.train.RMSPropOptimizer,学习率为0.001,衰减值为0.9,使损失最小;
# 定义预测的操作(predict_op)
cost = tf.reduce_mean(tf.nn. softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)
#定义训练时的批次大小和评估时的批次大小
batch_size = 128
test_size = 256
#在一个会话中启动图,开始训练和评估
# Launch the graph in a session
with tf.Session() as sess:
 # you need to initialize all variables
 tf. global_variables_initializer().run()
 for i in range(100):
  training_batch = zip(range(0, len(trX), batch_size),
        range(batch_size, len(trX)+1, batch_size))
  for start, end in training_batch:
   sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
           p_keep_conv: 0.8, p_keep_hidden: 0.5})
 
  test_indices = np.arange(len(teX)) # Get A Test Batch
  np.random.shuffle(test_indices)
  test_indices = test_indices[0:test_size]
 
  print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
       sess.run(predict_op, feed_dict={X: teX[test_indices],
               p_keep_conv: 1.0,
               p_keep_hidden: 1.0})))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用urllib模块开发的多线程豆瓣小站mp3下载器
Jan 16 Python
python通过pil为png图片填充上背景颜色的方法
Mar 17 Python
利用ctypes提高Python的执行速度
Sep 09 Python
Python用zip函数同时遍历多个迭代器示例详解
Nov 14 Python
python出现"IndentationError: unexpected indent"错误解决办法
Oct 15 Python
教你用Python创建微信聊天机器人
Mar 31 Python
python如何实现反向迭代
Mar 20 Python
libreoffice python 操作word及excel文档的方法
Jul 04 Python
python manage.py runserver流程解析
Nov 08 Python
django的autoreload机制实现
Jun 03 Python
经验丰富程序员才知道的8种高级Python技巧
Jul 27 Python
scrapy-splash简单使用详解
Feb 21 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
You might like
用PHP生成html分页列表的代码
2007/03/18 PHP
收集的PHP中与数组相关的函数
2007/03/22 PHP
php pack与unpack 摸板字符字符含义
2009/10/29 PHP
jQuery获取json后使用zy_tmpl生成下拉菜单
2015/03/27 PHP
深入理解PHP内核(二)之SAPI探究
2015/11/10 PHP
JavaScript 设计模式 安全沙箱模式
2010/09/24 Javascript
Javascript this 的一些学习总结
2012/08/31 Javascript
Javascript倒计时页面跳转实例小结
2013/09/11 Javascript
js中继承的几种用法总结(apply,call,prototype)
2013/12/26 Javascript
jQuery设置和移除文本框默认值的方法
2015/03/09 Javascript
用JavaScript实现对话框的教程
2015/06/04 Javascript
Windows下用PyCharm和Visual Studio开始Python编程
2015/10/26 Javascript
延时加载JavaScript代码提高速度
2015/12/27 Javascript
使用Node.js实现ORM的一种思路详解(图文)
2017/10/24 Javascript
VUE+elementui组件在table-cell单元格中绘制微型echarts图
2020/04/20 Javascript
在vue中created、mounted等方法使用小结
2020/07/21 Javascript
Python pass 语句使用示例
2014/03/11 Python
Python THREADING模块中的JOIN()方法深入理解
2015/02/18 Python
python字符串连接方法分析
2016/04/12 Python
python数据预处理之将类别数据转换为数值的方法
2017/07/05 Python
Python3远程监控程序的实现方法
2019/07/15 Python
django框架ModelForm组件用法详解
2019/12/11 Python
python中使用input()函数获取用户输入值方式
2020/05/03 Python
python利用paramiko实现交换机巡检的示例
2020/09/22 Python
Python列表嵌套常见坑点及解决方案
2020/09/30 Python
Pycharm-community-2020.2.3 社区版安装教程图文详解
2020/12/08 Python
简述Html5 IphoneX 适配方法
2018/02/08 HTML / CSS
入党积极分子思想汇报
2014/01/02 职场文书
放飞中国梦演讲稿
2014/04/23 职场文书
学生期末评语大全
2014/04/30 职场文书
社区优秀志愿者先进事迹
2014/05/09 职场文书
松材线虫病防治方案
2014/06/15 职场文书
小学数学国培研修日志
2015/11/13 职场文书
2016十一国庆节慰问信
2015/12/01 职场文书
oracle重置序列从0开始递增1
2022/02/28 Oracle
深入理解mysql事务隔离级别和存储引擎
2022/04/12 MySQL