tensorflow使用CNN分析mnist手写体数字数据集


Posted in Python onJune 17, 2020

本文实例为大家分享了tensorflow使用CNN分析mnist手写体数字数据集,供大家参考,具体内容如下

import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.examples.tutorials.mnist import input_data
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#把上述trX和teX的形状变为[-1,28,28,1],-1表示不考虑输入图片的数量,28×28是图片的长和宽的像素数,
# 1是通道(channel)数量,因为MNIST的图片是黑白的,所以通道是1,如果是RGB彩色图像,通道是3。
trX = trX.reshape(-1, 28, 28, 1) # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1) # 28x28x1 input img
 
X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])
#初始化权重与定义网络结构。
# 这里,我们将要构建一个拥有3个卷积层和3个池化层,随后接1个全连接层和1个输出层的卷积神经网络
def init_weights(shape):
 return tf.Variable(tf.random_normal(shape, stddev=0.01))
 
w = init_weights([3, 3, 1, 32])   # patch大小为3×3,输入维度为1,输出维度为32
w2 = init_weights([3, 3, 32, 64])   # patch大小为3×3,输入维度为32,输出维度为64
w3 = init_weights([3, 3, 64, 128])   # patch大小为3×3,输入维度为64,输出维度为128
w4 = init_weights([128 * 4 * 4, 625])  # 全连接层,输入维度为 128 × 4 × 4,是上一层的输出数据又三维的转变成一维, 输出维度为625
w_o = init_weights([625, 10]) # 输出层,输入维度为 625, 输出维度为10,代表10类(labels)
# 神经网络模型的构建函数,传入以下参数
# X:输入数据
# w:每一层的权重
# p_keep_conv,p_keep_hidden:dropout要保留的神经元比例
 
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
 # 第一组卷积层及池化层,最后dropout一些神经元
 l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
 # l1a shape=(?, 28, 28, 32)
 l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l1 shape=(?, 14, 14, 32)
 l1 = tf.nn.dropout(l1, p_keep_conv)
 
 # 第二组卷积层及池化层,最后dropout一些神经元
 l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
 # l2a shape=(?, 14, 14, 64)
 l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l2 shape=(?, 7, 7, 64)
 l2 = tf.nn.dropout(l2, p_keep_conv)
 # 第三组卷积层及池化层,最后dropout一些神经元
 l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
 # l3a shape=(?, 7, 7, 128)
 l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l3 shape=(?, 4, 4, 128)
 l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]]) # reshape to (?, 2048)
 l3 = tf.nn.dropout(l3, p_keep_conv)
 # 全连接层,最后dropout一些神经元
 l4 = tf.nn.relu(tf.matmul(l3, w4))
 l4 = tf.nn.dropout(l4, p_keep_hidden)
 # 输出层
 pyx = tf.matmul(l4, w_o)
 return pyx #返回预测值
 
#我们定义dropout的占位符——keep_conv,它表示在一层中有多少比例的神经元被保留下来。生成网络模型,得到预测值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) #得到预测值
#定义损失函数,这里我们仍然采用tf.nn.softmax_cross_entropy_with_logits来比较预测值和真实值的差异,并做均值处理;
# 定义训练的操作(train_op),采用实现RMSProp算法的优化器tf.train.RMSPropOptimizer,学习率为0.001,衰减值为0.9,使损失最小;
# 定义预测的操作(predict_op)
cost = tf.reduce_mean(tf.nn. softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)
#定义训练时的批次大小和评估时的批次大小
batch_size = 128
test_size = 256
#在一个会话中启动图,开始训练和评估
# Launch the graph in a session
with tf.Session() as sess:
 # you need to initialize all variables
 tf. global_variables_initializer().run()
 for i in range(100):
  training_batch = zip(range(0, len(trX), batch_size),
        range(batch_size, len(trX)+1, batch_size))
  for start, end in training_batch:
   sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
           p_keep_conv: 0.8, p_keep_hidden: 0.5})
 
  test_indices = np.arange(len(teX)) # Get A Test Batch
  np.random.shuffle(test_indices)
  test_indices = test_indices[0:test_size]
 
  print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
       sess.run(predict_op, feed_dict={X: teX[test_indices],
               p_keep_conv: 1.0,
               p_keep_hidden: 1.0})))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python构造函数及解构函数介绍
Feb 26 Python
Python中比较特别的除法运算和幂运算介绍
Apr 05 Python
python操作ie登陆土豆网的方法
May 09 Python
Python 制作糗事百科爬虫实例
Sep 22 Python
Python交互环境下实现输入代码
Jun 22 Python
pyQt4实现俄罗斯方块游戏
Jun 26 Python
判断python对象是否可调用的三种方式及其区别详解
Jan 31 Python
Python Django中间件,中间件函数,全局异常处理操作示例
Nov 08 Python
python thrift 实现 单端口多服务的过程
Jun 08 Python
python爬虫爬取网页数据并解析数据
Sep 18 Python
在Windows下安装配置CPU版的PyTorch的方法
Apr 02 Python
解决pycharm安装scrapy DLL load failed:找不到指定的程序的问题
Jun 08 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
You might like
星际流派综述
2020/03/04 星际争霸
qq登录,新浪微博登录接口申请过程中遇到的问题
2014/07/22 PHP
php实现过滤表单提交中html标签的方法
2014/10/17 PHP
php操作(删除,提取,增加)zip文件方法详解
2015/03/12 PHP
jquery tools系列 expose 学习
2009/09/06 Javascript
Javascript学习笔记1 数据类型
2010/01/11 Javascript
js中判断文本框是否为空的两种方法
2011/07/31 Javascript
推荐8款jQuery轻量级树形Tree插件
2014/11/12 Javascript
javascript限制用户只能输汉字中文的方法
2014/11/20 Javascript
原生js结合html5制作小飞龙的简易跳球
2015/03/30 Javascript
JavaScript中String.prototype用法实例
2015/05/20 Javascript
通过js获取上传的图片信息(临时保存路径,名称,大小)然后通过ajax传递给后端的方法
2015/10/01 Javascript
数据结构中的各种排序方法小结(JS实现)
2016/07/23 Javascript
javascript实现根据汉字获取简拼
2016/09/25 Javascript
BootStrap 图片样式、辅助类样式和CSS组件的实例详解
2017/01/20 Javascript
基于JavaScript实现复选框的全选和取消全选
2017/02/09 Javascript
vue 虚拟dom的patch源码分析
2018/03/01 Javascript
element-ui 关于获取select 的label值方法
2018/08/24 Javascript
toString.call()通用的判断数据类型方法示例
2020/08/28 Javascript
如何在vue中使用video.js播放m3u8格式的视频
2021/02/01 Vue.js
Python中的数据对象持久化存储模块pickle的使用示例
2016/03/03 Python
qpython3 读取安卓lastpass Cookies
2016/06/19 Python
使用pandas对矢量化数据进行替换处理的方法
2018/04/11 Python
python中virtualenvwrapper安装与使用
2018/05/20 Python
Python学习笔记之图片人脸检测识别实例教程
2019/03/06 Python
基于Django实现日志记录报错信息
2019/12/17 Python
详解python对象之间的交互
2020/09/29 Python
Python数据模型与Python对象模型的相关总结
2021/01/26 Python
Pytorch实现WGAN用于动漫头像生成
2021/03/04 Python
Boom手表官网:瑞典手表品牌,设计你的手表
2019/03/11 全球购物
歌唱比赛主持词
2014/03/18 职场文书
绘画专业自荐信
2014/07/04 职场文书
2015年见习期工作总结
2014/12/12 职场文书
Oracle笔记
2021/04/05 Oracle
MySQL 分组查询的优化方法
2021/05/12 MySQL
html粘性页脚的具体使用
2022/01/18 HTML / CSS