tensorflow使用CNN分析mnist手写体数字数据集


Posted in Python onJune 17, 2020

本文实例为大家分享了tensorflow使用CNN分析mnist手写体数字数据集,供大家参考,具体内容如下

import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.examples.tutorials.mnist import input_data
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#把上述trX和teX的形状变为[-1,28,28,1],-1表示不考虑输入图片的数量,28×28是图片的长和宽的像素数,
# 1是通道(channel)数量,因为MNIST的图片是黑白的,所以通道是1,如果是RGB彩色图像,通道是3。
trX = trX.reshape(-1, 28, 28, 1) # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1) # 28x28x1 input img
 
X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])
#初始化权重与定义网络结构。
# 这里,我们将要构建一个拥有3个卷积层和3个池化层,随后接1个全连接层和1个输出层的卷积神经网络
def init_weights(shape):
 return tf.Variable(tf.random_normal(shape, stddev=0.01))
 
w = init_weights([3, 3, 1, 32])   # patch大小为3×3,输入维度为1,输出维度为32
w2 = init_weights([3, 3, 32, 64])   # patch大小为3×3,输入维度为32,输出维度为64
w3 = init_weights([3, 3, 64, 128])   # patch大小为3×3,输入维度为64,输出维度为128
w4 = init_weights([128 * 4 * 4, 625])  # 全连接层,输入维度为 128 × 4 × 4,是上一层的输出数据又三维的转变成一维, 输出维度为625
w_o = init_weights([625, 10]) # 输出层,输入维度为 625, 输出维度为10,代表10类(labels)
# 神经网络模型的构建函数,传入以下参数
# X:输入数据
# w:每一层的权重
# p_keep_conv,p_keep_hidden:dropout要保留的神经元比例
 
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
 # 第一组卷积层及池化层,最后dropout一些神经元
 l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
 # l1a shape=(?, 28, 28, 32)
 l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l1 shape=(?, 14, 14, 32)
 l1 = tf.nn.dropout(l1, p_keep_conv)
 
 # 第二组卷积层及池化层,最后dropout一些神经元
 l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
 # l2a shape=(?, 14, 14, 64)
 l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l2 shape=(?, 7, 7, 64)
 l2 = tf.nn.dropout(l2, p_keep_conv)
 # 第三组卷积层及池化层,最后dropout一些神经元
 l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
 # l3a shape=(?, 7, 7, 128)
 l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l3 shape=(?, 4, 4, 128)
 l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]]) # reshape to (?, 2048)
 l3 = tf.nn.dropout(l3, p_keep_conv)
 # 全连接层,最后dropout一些神经元
 l4 = tf.nn.relu(tf.matmul(l3, w4))
 l4 = tf.nn.dropout(l4, p_keep_hidden)
 # 输出层
 pyx = tf.matmul(l4, w_o)
 return pyx #返回预测值
 
#我们定义dropout的占位符——keep_conv,它表示在一层中有多少比例的神经元被保留下来。生成网络模型,得到预测值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) #得到预测值
#定义损失函数,这里我们仍然采用tf.nn.softmax_cross_entropy_with_logits来比较预测值和真实值的差异,并做均值处理;
# 定义训练的操作(train_op),采用实现RMSProp算法的优化器tf.train.RMSPropOptimizer,学习率为0.001,衰减值为0.9,使损失最小;
# 定义预测的操作(predict_op)
cost = tf.reduce_mean(tf.nn. softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)
#定义训练时的批次大小和评估时的批次大小
batch_size = 128
test_size = 256
#在一个会话中启动图,开始训练和评估
# Launch the graph in a session
with tf.Session() as sess:
 # you need to initialize all variables
 tf. global_variables_initializer().run()
 for i in range(100):
  training_batch = zip(range(0, len(trX), batch_size),
        range(batch_size, len(trX)+1, batch_size))
  for start, end in training_batch:
   sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
           p_keep_conv: 0.8, p_keep_hidden: 0.5})
 
  test_indices = np.arange(len(teX)) # Get A Test Batch
  np.random.shuffle(test_indices)
  test_indices = test_indices[0:test_size]
 
  print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
       sess.run(predict_op, feed_dict={X: teX[test_indices],
               p_keep_conv: 1.0,
               p_keep_hidden: 1.0})))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过wxPython打开一个音频文件并播放的方法
Mar 25 Python
python中for语句简单遍历数据的方法
May 07 Python
python email smtplib模块发送邮件代码实例
Apr 26 Python
python for循环输入一个矩阵的实例
Nov 14 Python
Python提取特定时间段内数据的方法实例
Apr 01 Python
详解python之heapq模块及排序操作
Apr 04 Python
Python实现决策树并且使用Graphviz可视化的例子
Aug 09 Python
Python中的 sort 和 sorted的用法与区别
Aug 10 Python
如何基于python对接钉钉并获取access_token
Apr 21 Python
Python SMTP配置参数并发送邮件
Jun 16 Python
解决pycharm 格式报错tabs和space不一致问题
Feb 26 Python
基于Python实现流星雨效果的绘制
Mar 18 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
You might like
PHP页面间传递参数实例代码
2008/06/05 PHP
Laravel 5框架学习之Eloquent 关系
2015/04/09 PHP
一张表搞清楚php is_null、empty、isset的区别
2015/07/07 PHP
php把数组值转换成键的方法
2015/07/13 PHP
php array_values 返回数组的值实例详解
2016/11/17 PHP
php引用和拷贝的区别知识点总结
2019/09/23 PHP
Jquery下:nth-child(an+b)的使用注意
2011/05/28 Javascript
js形成页面的一种遮罩效果实例代码
2014/01/04 Javascript
原生JS实现拖拽图片效果
2020/08/27 Javascript
Bootstrap页面布局基础知识全面解析
2016/06/13 Javascript
简单实现bootstrap导航效果
2017/02/07 Javascript
jQuery EasyUI Draggable拖动组件
2017/03/01 Javascript
原生JS实现日历组件的示例代码
2017/09/22 Javascript
Node.js 多线程完全指南总结
2019/03/27 Javascript
[07:26]2015国际邀请赛第二日TOP10集锦
2015/08/06 DOTA
[01:04:14]OG vs Winstrike 2018国际邀请赛小组赛BO2 第二场 8.19
2018/08/21 DOTA
python实现的简单文本类游戏实例
2015/04/28 Python
在Python中操作字符串之replace()方法的使用
2015/05/19 Python
python实现跨excel的工作表sheet之间的复制方法
2018/05/03 Python
Python读写docx文件的方法
2018/05/08 Python
对python抓取需要登录网站数据的方法详解
2018/05/21 Python
Python获取系统所有进程PID及进程名称的方法示例
2018/05/24 Python
python 中pyqt5 树节点点击实现多窗口切换问题
2019/07/04 Python
Python API 操作Hadoop hdfs详解
2020/06/06 Python
Python如何把字典写入到CSV文件的方法示例
2020/08/23 Python
Django3中的自定义用户模型实例详解
2020/08/23 Python
纯CSS3实现绘制各种图形实现代码详细整理
2012/12/26 HTML / CSS
CSS3 对过渡(transition)进行调速以及延时
2020/10/21 HTML / CSS
德国BA保镖药房韩文网:kr.ba.de
2017/09/04 全球购物
NYX Professional Makeup官方网站:专业彩妆和美容产品
2019/10/29 全球购物
一名女生的自荐信
2013/12/08 职场文书
大学军训感言
2014/01/10 职场文书
思想品德课教学反思
2014/02/10 职场文书
毕业生求职自荐信(2016最新版)
2016/01/28 职场文书
2016年综治宣传月活动宣传标语口号
2016/03/16 职场文书
vue/cli 配置动态代理无需重启服务的方法
2022/05/20 Vue.js