tensorflow使用CNN分析mnist手写体数字数据集


Posted in Python onJune 17, 2020

本文实例为大家分享了tensorflow使用CNN分析mnist手写体数字数据集,供大家参考,具体内容如下

import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.examples.tutorials.mnist import input_data
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#把上述trX和teX的形状变为[-1,28,28,1],-1表示不考虑输入图片的数量,28×28是图片的长和宽的像素数,
# 1是通道(channel)数量,因为MNIST的图片是黑白的,所以通道是1,如果是RGB彩色图像,通道是3。
trX = trX.reshape(-1, 28, 28, 1) # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1) # 28x28x1 input img
 
X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])
#初始化权重与定义网络结构。
# 这里,我们将要构建一个拥有3个卷积层和3个池化层,随后接1个全连接层和1个输出层的卷积神经网络
def init_weights(shape):
 return tf.Variable(tf.random_normal(shape, stddev=0.01))
 
w = init_weights([3, 3, 1, 32])   # patch大小为3×3,输入维度为1,输出维度为32
w2 = init_weights([3, 3, 32, 64])   # patch大小为3×3,输入维度为32,输出维度为64
w3 = init_weights([3, 3, 64, 128])   # patch大小为3×3,输入维度为64,输出维度为128
w4 = init_weights([128 * 4 * 4, 625])  # 全连接层,输入维度为 128 × 4 × 4,是上一层的输出数据又三维的转变成一维, 输出维度为625
w_o = init_weights([625, 10]) # 输出层,输入维度为 625, 输出维度为10,代表10类(labels)
# 神经网络模型的构建函数,传入以下参数
# X:输入数据
# w:每一层的权重
# p_keep_conv,p_keep_hidden:dropout要保留的神经元比例
 
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
 # 第一组卷积层及池化层,最后dropout一些神经元
 l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
 # l1a shape=(?, 28, 28, 32)
 l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l1 shape=(?, 14, 14, 32)
 l1 = tf.nn.dropout(l1, p_keep_conv)
 
 # 第二组卷积层及池化层,最后dropout一些神经元
 l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
 # l2a shape=(?, 14, 14, 64)
 l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l2 shape=(?, 7, 7, 64)
 l2 = tf.nn.dropout(l2, p_keep_conv)
 # 第三组卷积层及池化层,最后dropout一些神经元
 l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
 # l3a shape=(?, 7, 7, 128)
 l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l3 shape=(?, 4, 4, 128)
 l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]]) # reshape to (?, 2048)
 l3 = tf.nn.dropout(l3, p_keep_conv)
 # 全连接层,最后dropout一些神经元
 l4 = tf.nn.relu(tf.matmul(l3, w4))
 l4 = tf.nn.dropout(l4, p_keep_hidden)
 # 输出层
 pyx = tf.matmul(l4, w_o)
 return pyx #返回预测值
 
#我们定义dropout的占位符——keep_conv,它表示在一层中有多少比例的神经元被保留下来。生成网络模型,得到预测值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) #得到预测值
#定义损失函数,这里我们仍然采用tf.nn.softmax_cross_entropy_with_logits来比较预测值和真实值的差异,并做均值处理;
# 定义训练的操作(train_op),采用实现RMSProp算法的优化器tf.train.RMSPropOptimizer,学习率为0.001,衰减值为0.9,使损失最小;
# 定义预测的操作(predict_op)
cost = tf.reduce_mean(tf.nn. softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)
#定义训练时的批次大小和评估时的批次大小
batch_size = 128
test_size = 256
#在一个会话中启动图,开始训练和评估
# Launch the graph in a session
with tf.Session() as sess:
 # you need to initialize all variables
 tf. global_variables_initializer().run()
 for i in range(100):
  training_batch = zip(range(0, len(trX), batch_size),
        range(batch_size, len(trX)+1, batch_size))
  for start, end in training_batch:
   sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
           p_keep_conv: 0.8, p_keep_hidden: 0.5})
 
  test_indices = np.arange(len(teX)) # Get A Test Batch
  np.random.shuffle(test_indices)
  test_indices = test_indices[0:test_size]
 
  print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
       sess.run(predict_op, feed_dict={X: teX[test_indices],
               p_keep_conv: 1.0,
               p_keep_hidden: 1.0})))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python实现递归版汉诺塔示例(汉诺塔递归算法)
Apr 08 Python
pip安装时ReadTimeoutError的解决方法
Jun 12 Python
基于Django框架利用Ajax实现点赞功能实例代码
Aug 19 Python
python单例模式获取IP代理的方法详解
Sep 13 Python
Python中flatten( )函数及函数用法详解
Nov 02 Python
pycharm配置pyqt5-tools开发环境的方法步骤
Feb 11 Python
Python爬虫学习之翻译小程序
Jul 30 Python
使用pycharm在本地开发并实时同步到服务器
Aug 02 Python
Python面向对象魔法方法和单例模块代码实例
Mar 25 Python
Django自定义列表 models字段显示方式
Apr 03 Python
python opencv实现图像配准与比较
Feb 09 Python
Python实现Excel自动分组合并单元格
Feb 22 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
You might like
PHP生成静态页面详解
2006/12/05 PHP
php检测useragent版本示例
2014/03/24 PHP
PHP中的多行字符串传递给JavaScript的两种方法
2014/06/19 PHP
typecho插件编写教程(三):保存配置
2015/05/28 PHP
php根据用户语言跳转相应网页
2015/11/04 PHP
PHP rmdir()函数的用法总结
2019/07/02 PHP
php下的原生ajax请求用法实例分析
2020/02/28 PHP
PHP变量的作用范围实例讲解
2020/12/22 PHP
cnblogs中在闪存中屏蔽某人的实现代码
2010/11/14 Javascript
19个很有用的 JavaScript库推荐
2011/06/27 Javascript
js时间日期和毫秒的相互转换
2013/02/22 Javascript
jQuery实现页面滚动时层智能浮动定位实例探讨
2013/03/29 Javascript
AngularJS入门教程之学习环境搭建
2014/12/06 Javascript
jQuery实现表格颜色交替显示的方法
2015/03/09 Javascript
纯js实现瀑布流布局及ajax动态新增数据
2016/04/07 Javascript
jquery单击文字或图片内容放大并居中显示
2017/06/23 jQuery
Nodejs之http的表单提交
2017/07/07 NodeJs
纯js实现页面返回顶部的动画(超简单)
2017/08/10 Javascript
vue实现打印功能的两种方法
2018/09/07 Javascript
ES10 特性的完整指南小结
2019/03/04 Javascript
JS插件amCharts实现绘制柱形图默认显示数值功能示例
2019/11/26 Javascript
Vue初始化中的选项合并之initInternalComponent详解
2020/06/11 Javascript
[44:37]完美世界DOTA2联赛PWL S3 Forest vs access 第一场 12.11
2020/12/13 DOTA
python中map()函数的使用方法示例
2017/09/29 Python
pandas 数据结构之Series的使用方法
2019/06/21 Python
pycharm 更改创建文件默认路径的操作
2020/02/15 Python
关于box-sizing的全面理解
2016/07/28 HTML / CSS
size?丹麦官网:英国伦敦的球鞋精品店
2019/04/15 全球购物
生物制药自我鉴定
2014/01/25 职场文书
高中生班主任评语
2014/04/25 职场文书
求职信怎么写范文
2014/05/26 职场文书
汽车维修求职信
2014/06/15 职场文书
处级领导班子全部召开专题民主生活会情况汇报
2014/09/27 职场文书
行政介绍信范文
2015/05/04 职场文书
mybatis3中@SelectProvider传递参数方式
2021/08/04 Java/Android
Python制作春联的示例代码
2022/01/22 Python