tensorflow使用CNN分析mnist手写体数字数据集


Posted in Python onJune 17, 2020

本文实例为大家分享了tensorflow使用CNN分析mnist手写体数字数据集,供大家参考,具体内容如下

import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.examples.tutorials.mnist import input_data
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#把上述trX和teX的形状变为[-1,28,28,1],-1表示不考虑输入图片的数量,28×28是图片的长和宽的像素数,
# 1是通道(channel)数量,因为MNIST的图片是黑白的,所以通道是1,如果是RGB彩色图像,通道是3。
trX = trX.reshape(-1, 28, 28, 1) # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1) # 28x28x1 input img
 
X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])
#初始化权重与定义网络结构。
# 这里,我们将要构建一个拥有3个卷积层和3个池化层,随后接1个全连接层和1个输出层的卷积神经网络
def init_weights(shape):
 return tf.Variable(tf.random_normal(shape, stddev=0.01))
 
w = init_weights([3, 3, 1, 32])   # patch大小为3×3,输入维度为1,输出维度为32
w2 = init_weights([3, 3, 32, 64])   # patch大小为3×3,输入维度为32,输出维度为64
w3 = init_weights([3, 3, 64, 128])   # patch大小为3×3,输入维度为64,输出维度为128
w4 = init_weights([128 * 4 * 4, 625])  # 全连接层,输入维度为 128 × 4 × 4,是上一层的输出数据又三维的转变成一维, 输出维度为625
w_o = init_weights([625, 10]) # 输出层,输入维度为 625, 输出维度为10,代表10类(labels)
# 神经网络模型的构建函数,传入以下参数
# X:输入数据
# w:每一层的权重
# p_keep_conv,p_keep_hidden:dropout要保留的神经元比例
 
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
 # 第一组卷积层及池化层,最后dropout一些神经元
 l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
 # l1a shape=(?, 28, 28, 32)
 l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l1 shape=(?, 14, 14, 32)
 l1 = tf.nn.dropout(l1, p_keep_conv)
 
 # 第二组卷积层及池化层,最后dropout一些神经元
 l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
 # l2a shape=(?, 14, 14, 64)
 l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l2 shape=(?, 7, 7, 64)
 l2 = tf.nn.dropout(l2, p_keep_conv)
 # 第三组卷积层及池化层,最后dropout一些神经元
 l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
 # l3a shape=(?, 7, 7, 128)
 l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l3 shape=(?, 4, 4, 128)
 l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]]) # reshape to (?, 2048)
 l3 = tf.nn.dropout(l3, p_keep_conv)
 # 全连接层,最后dropout一些神经元
 l4 = tf.nn.relu(tf.matmul(l3, w4))
 l4 = tf.nn.dropout(l4, p_keep_hidden)
 # 输出层
 pyx = tf.matmul(l4, w_o)
 return pyx #返回预测值
 
#我们定义dropout的占位符——keep_conv,它表示在一层中有多少比例的神经元被保留下来。生成网络模型,得到预测值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) #得到预测值
#定义损失函数,这里我们仍然采用tf.nn.softmax_cross_entropy_with_logits来比较预测值和真实值的差异,并做均值处理;
# 定义训练的操作(train_op),采用实现RMSProp算法的优化器tf.train.RMSPropOptimizer,学习率为0.001,衰减值为0.9,使损失最小;
# 定义预测的操作(predict_op)
cost = tf.reduce_mean(tf.nn. softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)
#定义训练时的批次大小和评估时的批次大小
batch_size = 128
test_size = 256
#在一个会话中启动图,开始训练和评估
# Launch the graph in a session
with tf.Session() as sess:
 # you need to initialize all variables
 tf. global_variables_initializer().run()
 for i in range(100):
  training_batch = zip(range(0, len(trX), batch_size),
        range(batch_size, len(trX)+1, batch_size))
  for start, end in training_batch:
   sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
           p_keep_conv: 0.8, p_keep_hidden: 0.5})
 
  test_indices = np.arange(len(teX)) # Get A Test Batch
  np.random.shuffle(test_indices)
  test_indices = test_indices[0:test_size]
 
  print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
       sess.run(predict_op, feed_dict={X: teX[test_indices],
               p_keep_conv: 1.0,
               p_keep_hidden: 1.0})))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python socket编程实例详解
May 27 Python
使用Python绘制图表大全总结
Feb 11 Python
Python 实现随机数详解及实例代码
Apr 15 Python
Python语言生成水仙花数代码示例
Dec 18 Python
Python多继承顺序实例分析
May 26 Python
python3 读取Excel表格中的数据
Oct 16 Python
如何利用python给图片添加半透明水印
Sep 06 Python
Python高级编程之继承问题详解(super与mro)
Nov 19 Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 Python
使用Python对Dicom文件进行读取与写入的实现
Apr 20 Python
Windows 下更改 jupyterlab 默认启动位置的教程详解
May 18 Python
Python之字符串的遍历的4种方式
Dec 08 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
You might like
我的论坛源代码(十)
2006/10/09 PHP
php 显示指定路径下的图片
2009/10/29 PHP
php网站被挂木马后的修复方法总结
2014/11/06 PHP
PHP实践教程之过滤、验证、转义与密码详解
2017/07/24 PHP
laravel框架使用FormRequest进行表单验证,验证异常返回JSON操作示例
2020/02/18 PHP
文本加密解密
2006/06/23 Javascript
jQuery Study Notes学习笔记 (二)
2010/08/04 Javascript
js判断iframe内的网页是否滚动到底部触发事件
2014/03/18 Javascript
jquery 显示*天*时*分*秒实现时间计时器
2014/05/07 Javascript
让angularjs支持浏览器自动填表
2014/11/10 Javascript
浅析javascript中函数声明和函数表达式的区别
2015/02/15 Javascript
介绍JavaScript的一个微型模版
2015/06/24 Javascript
js实现根据身份证号自动生成出生日期
2015/12/15 Javascript
Web打印解决方案之普通报表打印功能
2016/08/29 Javascript
详谈表单格式化插件jquery.serializeJSON
2017/06/23 jQuery
解决ant Design Search无法输入内容的问题
2020/10/29 Javascript
vue element-ul实现展开和收起功能的实例代码
2020/11/25 Vue.js
[01:52]2020年DOTA2 TI10夏季活动预告片
2020/07/15 DOTA
python 打印直角三角形,等边三角形,菱形,正方形的代码
2017/11/21 Python
Python 中的range(),以及列表切片方法
2018/07/02 Python
python实现向微信用户发送每日一句 python实现微信聊天机器人
2019/03/27 Python
python调用webservice接口的实现
2019/07/12 Python
Python时间序列缺失值的处理方法(日期缺失填充)
2019/08/11 Python
python中 _、__、__xx__()区别及使用场景
2020/06/30 Python
Python的collections模块真的很好用
2021/03/01 Python
利用Bootstrap实现漂亮简洁的CSS3价格表实例源码
2017/03/02 HTML / CSS
Boden美国官网:英伦原创时装品牌
2017/07/03 全球购物
教师专业理论水平的自我评价分享
2013/11/09 职场文书
应届生如何写自荐信
2014/01/05 职场文书
《散步》教学反思
2014/03/02 职场文书
煤矿安全协议书
2014/08/20 职场文书
官僚主义现象查摆问题整改措施
2014/10/04 职场文书
同学会演讲稿
2019/04/02 职场文书
毕业生自荐求职信书写的技巧
2019/08/26 职场文书
正确使用MySQL update语句
2021/05/26 MySQL
解析python中的jsonpath 提取器
2022/01/18 Python