tensorflow使用CNN分析mnist手写体数字数据集


Posted in Python onJune 17, 2020

本文实例为大家分享了tensorflow使用CNN分析mnist手写体数字数据集,供大家参考,具体内容如下

import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.examples.tutorials.mnist import input_data
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#把上述trX和teX的形状变为[-1,28,28,1],-1表示不考虑输入图片的数量,28×28是图片的长和宽的像素数,
# 1是通道(channel)数量,因为MNIST的图片是黑白的,所以通道是1,如果是RGB彩色图像,通道是3。
trX = trX.reshape(-1, 28, 28, 1) # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1) # 28x28x1 input img
 
X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])
#初始化权重与定义网络结构。
# 这里,我们将要构建一个拥有3个卷积层和3个池化层,随后接1个全连接层和1个输出层的卷积神经网络
def init_weights(shape):
 return tf.Variable(tf.random_normal(shape, stddev=0.01))
 
w = init_weights([3, 3, 1, 32])   # patch大小为3×3,输入维度为1,输出维度为32
w2 = init_weights([3, 3, 32, 64])   # patch大小为3×3,输入维度为32,输出维度为64
w3 = init_weights([3, 3, 64, 128])   # patch大小为3×3,输入维度为64,输出维度为128
w4 = init_weights([128 * 4 * 4, 625])  # 全连接层,输入维度为 128 × 4 × 4,是上一层的输出数据又三维的转变成一维, 输出维度为625
w_o = init_weights([625, 10]) # 输出层,输入维度为 625, 输出维度为10,代表10类(labels)
# 神经网络模型的构建函数,传入以下参数
# X:输入数据
# w:每一层的权重
# p_keep_conv,p_keep_hidden:dropout要保留的神经元比例
 
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
 # 第一组卷积层及池化层,最后dropout一些神经元
 l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
 # l1a shape=(?, 28, 28, 32)
 l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l1 shape=(?, 14, 14, 32)
 l1 = tf.nn.dropout(l1, p_keep_conv)
 
 # 第二组卷积层及池化层,最后dropout一些神经元
 l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
 # l2a shape=(?, 14, 14, 64)
 l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l2 shape=(?, 7, 7, 64)
 l2 = tf.nn.dropout(l2, p_keep_conv)
 # 第三组卷积层及池化层,最后dropout一些神经元
 l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
 # l3a shape=(?, 7, 7, 128)
 l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l3 shape=(?, 4, 4, 128)
 l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]]) # reshape to (?, 2048)
 l3 = tf.nn.dropout(l3, p_keep_conv)
 # 全连接层,最后dropout一些神经元
 l4 = tf.nn.relu(tf.matmul(l3, w4))
 l4 = tf.nn.dropout(l4, p_keep_hidden)
 # 输出层
 pyx = tf.matmul(l4, w_o)
 return pyx #返回预测值
 
#我们定义dropout的占位符——keep_conv,它表示在一层中有多少比例的神经元被保留下来。生成网络模型,得到预测值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) #得到预测值
#定义损失函数,这里我们仍然采用tf.nn.softmax_cross_entropy_with_logits来比较预测值和真实值的差异,并做均值处理;
# 定义训练的操作(train_op),采用实现RMSProp算法的优化器tf.train.RMSPropOptimizer,学习率为0.001,衰减值为0.9,使损失最小;
# 定义预测的操作(predict_op)
cost = tf.reduce_mean(tf.nn. softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)
#定义训练时的批次大小和评估时的批次大小
batch_size = 128
test_size = 256
#在一个会话中启动图,开始训练和评估
# Launch the graph in a session
with tf.Session() as sess:
 # you need to initialize all variables
 tf. global_variables_initializer().run()
 for i in range(100):
  training_batch = zip(range(0, len(trX), batch_size),
        range(batch_size, len(trX)+1, batch_size))
  for start, end in training_batch:
   sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
           p_keep_conv: 0.8, p_keep_hidden: 0.5})
 
  test_indices = np.arange(len(teX)) # Get A Test Batch
  np.random.shuffle(test_indices)
  test_indices = test_indices[0:test_size]
 
  print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
       sess.run(predict_op, feed_dict={X: teX[test_indices],
               p_keep_conv: 1.0,
               p_keep_hidden: 1.0})))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基于mysql实现的简单队列以及跨进程锁实例详解
Jul 07 Python
python文件与目录操作实例详解
Feb 22 Python
深入解析Python中的线程同步方法
Jun 14 Python
Python列表list操作符实例分析【标准类型操作符、切片、连接字符、列表解析、重复操作等】
Jul 24 Python
python如何读写csv数据
Mar 21 Python
Python实现的求解最小公倍数算法示例
May 03 Python
python 实现的发送邮件模板【普通邮件、带附件、带图片邮件】
Jul 06 Python
python3实现的zip格式压缩文件夹操作示例
Aug 17 Python
numpy ndarray 按条件筛选数组,关联筛选的例子
Nov 26 Python
PyCharm 2020.2 安装详细教程
Sep 25 Python
python 实现aes256加密
Nov 27 Python
python 下载文件的几种方式分享
Apr 07 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
You might like
php连接数据库代码应用分析
2011/05/29 PHP
php冒泡排序、快速排序、快速查找、二维数组去重实例分享
2014/04/24 PHP
PHP 之 写时复制介绍(Copy On Write)
2014/05/13 PHP
symfony2.4的twig中date用法分析
2016/03/18 PHP
ThinkPHP实现简单登陆功能
2017/04/28 PHP
php上传后台无法收到数据解决方法
2019/10/28 PHP
ExtJS Grid使用SimpleStore、多选框的方法
2009/11/20 Javascript
拥抱模块化的JavaScript
2012/03/07 Javascript
NodeJS 模块开发及发布详解分享
2012/03/07 NodeJs
jQuery+css实现百度百科的页面导航效果
2014/12/16 Javascript
JavaScript实现彩虹文字效果的方法
2015/04/16 Javascript
jQuery插件实现静态HTML验证码校验
2015/11/06 Javascript
浅谈jquery中next与siblings的区别
2016/10/27 Javascript
ES6中箭头函数的定义与调用方式详解
2017/06/02 Javascript
layui 动态设置checbox 选中状态的例子
2019/09/02 Javascript
vue实现购物车结算功能
2020/06/18 Javascript
浅谈Vue开发人员的7个最好的VSCode扩展
2021/01/20 Vue.js
python通过字典dict判断指定键值是否存在的方法
2015/03/21 Python
Python使用剪切板的方法
2017/06/06 Python
K-近邻算法的python实现代码分享
2017/12/09 Python
Python操作MongoDB数据库的方法示例
2018/01/04 Python
详解python异步编程之asyncio(百万并发)
2018/07/07 Python
Python pip 安装与使用(安装、更新、删除)
2019/10/06 Python
详解HTML5 data-* 自定义属性
2018/01/24 HTML / CSS
美国求婚钻戒网站:Super Jeweler
2016/08/27 全球购物
倩碧香港官方网站:Clinique香港
2017/11/13 全球购物
电大自我鉴定
2013/10/27 职场文书
工厂厂长的职责
2013/12/12 职场文书
员工年终演讲稿
2014/01/03 职场文书
三八妇女节活动主持词
2014/03/17 职场文书
报纸媒体创意广告词
2014/03/17 职场文书
职务说明书范文
2014/05/07 职场文书
党的群众路线教育实践活动个人批评与自我批评
2014/10/16 职场文书
搬迁通知
2015/04/20 职场文书
生日寿星公答谢词
2015/09/29 职场文书
终止合同协议书范本
2016/03/22 职场文书