tensorflow使用CNN分析mnist手写体数字数据集


Posted in Python onJune 17, 2020

本文实例为大家分享了tensorflow使用CNN分析mnist手写体数字数据集,供大家参考,具体内容如下

import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.examples.tutorials.mnist import input_data
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#把上述trX和teX的形状变为[-1,28,28,1],-1表示不考虑输入图片的数量,28×28是图片的长和宽的像素数,
# 1是通道(channel)数量,因为MNIST的图片是黑白的,所以通道是1,如果是RGB彩色图像,通道是3。
trX = trX.reshape(-1, 28, 28, 1) # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1) # 28x28x1 input img
 
X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])
#初始化权重与定义网络结构。
# 这里,我们将要构建一个拥有3个卷积层和3个池化层,随后接1个全连接层和1个输出层的卷积神经网络
def init_weights(shape):
 return tf.Variable(tf.random_normal(shape, stddev=0.01))
 
w = init_weights([3, 3, 1, 32])   # patch大小为3×3,输入维度为1,输出维度为32
w2 = init_weights([3, 3, 32, 64])   # patch大小为3×3,输入维度为32,输出维度为64
w3 = init_weights([3, 3, 64, 128])   # patch大小为3×3,输入维度为64,输出维度为128
w4 = init_weights([128 * 4 * 4, 625])  # 全连接层,输入维度为 128 × 4 × 4,是上一层的输出数据又三维的转变成一维, 输出维度为625
w_o = init_weights([625, 10]) # 输出层,输入维度为 625, 输出维度为10,代表10类(labels)
# 神经网络模型的构建函数,传入以下参数
# X:输入数据
# w:每一层的权重
# p_keep_conv,p_keep_hidden:dropout要保留的神经元比例
 
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
 # 第一组卷积层及池化层,最后dropout一些神经元
 l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
 # l1a shape=(?, 28, 28, 32)
 l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l1 shape=(?, 14, 14, 32)
 l1 = tf.nn.dropout(l1, p_keep_conv)
 
 # 第二组卷积层及池化层,最后dropout一些神经元
 l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
 # l2a shape=(?, 14, 14, 64)
 l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l2 shape=(?, 7, 7, 64)
 l2 = tf.nn.dropout(l2, p_keep_conv)
 # 第三组卷积层及池化层,最后dropout一些神经元
 l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
 # l3a shape=(?, 7, 7, 128)
 l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l3 shape=(?, 4, 4, 128)
 l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]]) # reshape to (?, 2048)
 l3 = tf.nn.dropout(l3, p_keep_conv)
 # 全连接层,最后dropout一些神经元
 l4 = tf.nn.relu(tf.matmul(l3, w4))
 l4 = tf.nn.dropout(l4, p_keep_hidden)
 # 输出层
 pyx = tf.matmul(l4, w_o)
 return pyx #返回预测值
 
#我们定义dropout的占位符——keep_conv,它表示在一层中有多少比例的神经元被保留下来。生成网络模型,得到预测值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) #得到预测值
#定义损失函数,这里我们仍然采用tf.nn.softmax_cross_entropy_with_logits来比较预测值和真实值的差异,并做均值处理;
# 定义训练的操作(train_op),采用实现RMSProp算法的优化器tf.train.RMSPropOptimizer,学习率为0.001,衰减值为0.9,使损失最小;
# 定义预测的操作(predict_op)
cost = tf.reduce_mean(tf.nn. softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)
#定义训练时的批次大小和评估时的批次大小
batch_size = 128
test_size = 256
#在一个会话中启动图,开始训练和评估
# Launch the graph in a session
with tf.Session() as sess:
 # you need to initialize all variables
 tf. global_variables_initializer().run()
 for i in range(100):
  training_batch = zip(range(0, len(trX), batch_size),
        range(batch_size, len(trX)+1, batch_size))
  for start, end in training_batch:
   sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
           p_keep_conv: 0.8, p_keep_hidden: 0.5})
 
  test_indices = np.arange(len(teX)) # Get A Test Batch
  np.random.shuffle(test_indices)
  test_indices = test_indices[0:test_size]
 
  print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
       sess.run(predict_op, feed_dict={X: teX[test_indices],
               p_keep_conv: 1.0,
               p_keep_hidden: 1.0})))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python标准库os.path包、glob包使用实例
Nov 25 Python
Python发送email的3种方法
Apr 28 Python
Python中使用装饰器时需要注意的一些问题
May 11 Python
详解Python中break语句的用法
May 14 Python
纯python实现机器学习之kNN算法示例
Mar 01 Python
详解pandas安装若干异常及解决方案总结
Jan 10 Python
解决Python selenium get页面很慢时的问题
Jan 30 Python
Python 使用list和tuple+条件判断详解
Jul 30 Python
python使用多线程编写tcp客户端程序
Sep 02 Python
Python namedtuple命名元组实现过程解析
Jan 08 Python
pycharm解决关闭flask后依旧可以访问服务的问题
Apr 03 Python
python实现批处理文件
Jul 28 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
You might like
解析php5配置使用pdo
2013/07/03 PHP
PHP中array_slice函数用法实例详解
2014/11/25 PHP
PHP实现生成唯一会员卡号
2015/08/24 PHP
PHP实现从上往下打印二叉树的方法
2018/01/18 PHP
PHP实现登录验证码校验功能
2018/05/17 PHP
php ActiveMQ的安装与使用方法图文教程
2020/02/23 PHP
javaScript 简单验证代码(用户名,密码,邮箱)
2009/09/28 Javascript
一些javascript一些题目的解析
2010/12/25 Javascript
JSON语法五大要素图文介绍
2012/12/04 Javascript
对于this和$(this)的个人理解
2013/09/08 Javascript
基于jquery实现等比缩放图片
2014/12/03 Javascript
基于jQuery实现歌词滚动版音乐播放器的代码
2016/09/17 Javascript
JavaScript实现树的遍历算法示例【广度优先与深度优先】
2017/10/26 Javascript
Vue父子模版传值及组件传值的三种方法
2017/11/27 Javascript
vue webpack打包后图片路径错误的完美解决方法
2018/12/07 Javascript
简单的React SSR服务器渲染实现
2018/12/11 Javascript
node.js中ws模块创建服务端和客户端,网页WebSocket客户端
2019/03/06 Javascript
webpack常用构建优化策略小结
2019/11/21 Javascript
[02:03]永远的信仰DOTA2 中国军团历届国际邀请赛回顾
2016/06/26 DOTA
Python内置函数Type()函数一个有趣的用法
2015/02/18 Python
在Python的Flask框架中实现单元测试的教程
2015/04/20 Python
Python设计模式编程中解释器模式的简单程序示例分享
2016/03/02 Python
Python初学者常见错误详解
2019/07/02 Python
django用户登录验证的完整示例代码
2019/07/21 Python
Python 图像对比度增强的几种方法(小结)
2019/09/25 Python
Python实现扫码工具的示例代码
2020/10/09 Python
基于Python实现天天酷跑功能
2021/01/06 Python
用HTML5制作视频拼图的教程
2015/05/13 HTML / CSS
白兰氏健康Mall:BRAND’S
2017/11/13 全球购物
在线购买世界上最好的酒:BoozeBud
2018/06/07 全球购物
中学生家长评语大全
2014/04/16 职场文书
公司领导班子对照材料
2014/08/18 职场文书
2015年保安个人工作总结
2015/04/02 职场文书
房地产置业顾问岗位职责
2015/04/11 职场文书
学校运动会加油词
2015/07/18 职场文书
图文详解Nginx版本平滑升级方案
2021/09/15 Servers