tensorflow使用CNN分析mnist手写体数字数据集


Posted in Python onJune 17, 2020

本文实例为大家分享了tensorflow使用CNN分析mnist手写体数字数据集,供大家参考,具体内容如下

import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.examples.tutorials.mnist import input_data
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#把上述trX和teX的形状变为[-1,28,28,1],-1表示不考虑输入图片的数量,28×28是图片的长和宽的像素数,
# 1是通道(channel)数量,因为MNIST的图片是黑白的,所以通道是1,如果是RGB彩色图像,通道是3。
trX = trX.reshape(-1, 28, 28, 1) # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1) # 28x28x1 input img
 
X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])
#初始化权重与定义网络结构。
# 这里,我们将要构建一个拥有3个卷积层和3个池化层,随后接1个全连接层和1个输出层的卷积神经网络
def init_weights(shape):
 return tf.Variable(tf.random_normal(shape, stddev=0.01))
 
w = init_weights([3, 3, 1, 32])   # patch大小为3×3,输入维度为1,输出维度为32
w2 = init_weights([3, 3, 32, 64])   # patch大小为3×3,输入维度为32,输出维度为64
w3 = init_weights([3, 3, 64, 128])   # patch大小为3×3,输入维度为64,输出维度为128
w4 = init_weights([128 * 4 * 4, 625])  # 全连接层,输入维度为 128 × 4 × 4,是上一层的输出数据又三维的转变成一维, 输出维度为625
w_o = init_weights([625, 10]) # 输出层,输入维度为 625, 输出维度为10,代表10类(labels)
# 神经网络模型的构建函数,传入以下参数
# X:输入数据
# w:每一层的权重
# p_keep_conv,p_keep_hidden:dropout要保留的神经元比例
 
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
 # 第一组卷积层及池化层,最后dropout一些神经元
 l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
 # l1a shape=(?, 28, 28, 32)
 l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l1 shape=(?, 14, 14, 32)
 l1 = tf.nn.dropout(l1, p_keep_conv)
 
 # 第二组卷积层及池化层,最后dropout一些神经元
 l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
 # l2a shape=(?, 14, 14, 64)
 l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l2 shape=(?, 7, 7, 64)
 l2 = tf.nn.dropout(l2, p_keep_conv)
 # 第三组卷积层及池化层,最后dropout一些神经元
 l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
 # l3a shape=(?, 7, 7, 128)
 l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l3 shape=(?, 4, 4, 128)
 l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]]) # reshape to (?, 2048)
 l3 = tf.nn.dropout(l3, p_keep_conv)
 # 全连接层,最后dropout一些神经元
 l4 = tf.nn.relu(tf.matmul(l3, w4))
 l4 = tf.nn.dropout(l4, p_keep_hidden)
 # 输出层
 pyx = tf.matmul(l4, w_o)
 return pyx #返回预测值
 
#我们定义dropout的占位符——keep_conv,它表示在一层中有多少比例的神经元被保留下来。生成网络模型,得到预测值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) #得到预测值
#定义损失函数,这里我们仍然采用tf.nn.softmax_cross_entropy_with_logits来比较预测值和真实值的差异,并做均值处理;
# 定义训练的操作(train_op),采用实现RMSProp算法的优化器tf.train.RMSPropOptimizer,学习率为0.001,衰减值为0.9,使损失最小;
# 定义预测的操作(predict_op)
cost = tf.reduce_mean(tf.nn. softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)
#定义训练时的批次大小和评估时的批次大小
batch_size = 128
test_size = 256
#在一个会话中启动图,开始训练和评估
# Launch the graph in a session
with tf.Session() as sess:
 # you need to initialize all variables
 tf. global_variables_initializer().run()
 for i in range(100):
  training_batch = zip(range(0, len(trX), batch_size),
        range(batch_size, len(trX)+1, batch_size))
  for start, end in training_batch:
   sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
           p_keep_conv: 0.8, p_keep_hidden: 0.5})
 
  test_indices = np.arange(len(teX)) # Get A Test Batch
  np.random.shuffle(test_indices)
  test_indices = test_indices[0:test_size]
 
  print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
       sess.run(predict_op, feed_dict={X: teX[test_indices],
               p_keep_conv: 1.0,
               p_keep_hidden: 1.0})))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python正则表达式的使用范例详解
Aug 08 Python
初学Python实用技巧两则
Aug 29 Python
Python的for和break循环结构中使用else语句的技巧
May 24 Python
解决出现Incorrect integer value: '' for column 'id' at row 1的问题
Oct 29 Python
对python中使用requests模块参数编码的不同处理方法
May 18 Python
Python基于SMTP协议实现发送邮件功能详解
Aug 14 Python
python实现自动登录后台管理系统
Oct 18 Python
pytorch-RNN进行回归曲线预测方式
Jan 14 Python
python super用法及原理详解
Jan 20 Python
浅谈keras保存模型中的save()和save_weights()区别
May 21 Python
Python 排序最长英文单词链(列表中前一个单词末字母是下一个单词的首字母)
Dec 14 Python
关于python爬虫应用urllib库作用分析
Sep 04 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
You might like
php命令行用法入门实例教程
2014/10/27 PHP
PHP分页类集锦
2014/11/18 PHP
浅谈PHP中JSON数据操作
2015/07/01 PHP
实例讲解YII2中多表关联的使用方法
2017/07/21 PHP
JavaScript中void(0)的具体含义解释
2007/02/27 Javascript
ListBox实现上移,下移,左移,右移的简单实例
2014/02/13 Javascript
JS简单操作select和dropdownlist实例
2014/11/26 Javascript
JS获取iframe中longdesc属性的方法
2015/04/01 Javascript
使用jspdf生成pdf报表
2015/07/03 Javascript
基于bootstrap3和jquery的分页插件
2015/07/31 Javascript
谈一谈jQuery核心架构设计
2016/03/28 Javascript
微信小程序 欢迎界面开发的实例详解
2016/11/30 Javascript
Angular2平滑升级到Angular4的步骤详解
2017/03/29 Javascript
jQuery插件FusionCharts绘制的3D环饼图效果示例【附demo源码】
2017/04/02 jQuery
手把手教你把nodejs部署到linux上跑出hello world
2017/06/19 NodeJs
Vue多系统切换实现方案
2018/06/05 Javascript
javaScript 实现重复输出给定的字符串的常用方法小结
2020/02/20 Javascript
Python实现partial改变方法默认参数
2014/08/18 Python
跟老齐学Python之??碌某?? target=
2014/09/12 Python
Python的Flask框架应用调用Redis队列数据的方法
2016/06/06 Python
python用pandas数据加载、存储与文件格式的实例
2018/12/07 Python
运用Python的webbrowser实现定时打开特定网页
2019/02/21 Python
python解析xml简单示例
2019/06/21 Python
Python 异常处理Ⅳ过程图解
2019/10/18 Python
Django日志及中间件模块应用案例
2020/09/10 Python
Mio Skincare中文官网:肌肤和身体护理
2016/10/26 全球购物
Marmot土拨鼠官网:美国专业户外运动品牌
2018/01/11 全球购物
岗位职责的含义
2013/11/17 职场文书
简易离婚协议书范本2014
2014/10/15 职场文书
2014年采购部工作总结
2014/11/20 职场文书
2015年宣传工作总结
2015/04/08 职场文书
幽默导游词开场白
2015/05/29 职场文书
《卖火柴的小女孩》教学反思
2016/02/19 职场文书
教育教学工作反思
2016/02/24 职场文书
nginx实现发布静态资源的方法
2021/03/31 Servers
使用HttpSessionListener监听器实战
2022/03/17 Java/Android