tensorflow使用CNN分析mnist手写体数字数据集


Posted in Python onJune 17, 2020

本文实例为大家分享了tensorflow使用CNN分析mnist手写体数字数据集,供大家参考,具体内容如下

import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.examples.tutorials.mnist import input_data
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#把上述trX和teX的形状变为[-1,28,28,1],-1表示不考虑输入图片的数量,28×28是图片的长和宽的像素数,
# 1是通道(channel)数量,因为MNIST的图片是黑白的,所以通道是1,如果是RGB彩色图像,通道是3。
trX = trX.reshape(-1, 28, 28, 1) # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1) # 28x28x1 input img
 
X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])
#初始化权重与定义网络结构。
# 这里,我们将要构建一个拥有3个卷积层和3个池化层,随后接1个全连接层和1个输出层的卷积神经网络
def init_weights(shape):
 return tf.Variable(tf.random_normal(shape, stddev=0.01))
 
w = init_weights([3, 3, 1, 32])   # patch大小为3×3,输入维度为1,输出维度为32
w2 = init_weights([3, 3, 32, 64])   # patch大小为3×3,输入维度为32,输出维度为64
w3 = init_weights([3, 3, 64, 128])   # patch大小为3×3,输入维度为64,输出维度为128
w4 = init_weights([128 * 4 * 4, 625])  # 全连接层,输入维度为 128 × 4 × 4,是上一层的输出数据又三维的转变成一维, 输出维度为625
w_o = init_weights([625, 10]) # 输出层,输入维度为 625, 输出维度为10,代表10类(labels)
# 神经网络模型的构建函数,传入以下参数
# X:输入数据
# w:每一层的权重
# p_keep_conv,p_keep_hidden:dropout要保留的神经元比例
 
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
 # 第一组卷积层及池化层,最后dropout一些神经元
 l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
 # l1a shape=(?, 28, 28, 32)
 l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l1 shape=(?, 14, 14, 32)
 l1 = tf.nn.dropout(l1, p_keep_conv)
 
 # 第二组卷积层及池化层,最后dropout一些神经元
 l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
 # l2a shape=(?, 14, 14, 64)
 l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l2 shape=(?, 7, 7, 64)
 l2 = tf.nn.dropout(l2, p_keep_conv)
 # 第三组卷积层及池化层,最后dropout一些神经元
 l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
 # l3a shape=(?, 7, 7, 128)
 l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l3 shape=(?, 4, 4, 128)
 l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]]) # reshape to (?, 2048)
 l3 = tf.nn.dropout(l3, p_keep_conv)
 # 全连接层,最后dropout一些神经元
 l4 = tf.nn.relu(tf.matmul(l3, w4))
 l4 = tf.nn.dropout(l4, p_keep_hidden)
 # 输出层
 pyx = tf.matmul(l4, w_o)
 return pyx #返回预测值
 
#我们定义dropout的占位符——keep_conv,它表示在一层中有多少比例的神经元被保留下来。生成网络模型,得到预测值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) #得到预测值
#定义损失函数,这里我们仍然采用tf.nn.softmax_cross_entropy_with_logits来比较预测值和真实值的差异,并做均值处理;
# 定义训练的操作(train_op),采用实现RMSProp算法的优化器tf.train.RMSPropOptimizer,学习率为0.001,衰减值为0.9,使损失最小;
# 定义预测的操作(predict_op)
cost = tf.reduce_mean(tf.nn. softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)
#定义训练时的批次大小和评估时的批次大小
batch_size = 128
test_size = 256
#在一个会话中启动图,开始训练和评估
# Launch the graph in a session
with tf.Session() as sess:
 # you need to initialize all variables
 tf. global_variables_initializer().run()
 for i in range(100):
  training_batch = zip(range(0, len(trX), batch_size),
        range(batch_size, len(trX)+1, batch_size))
  for start, end in training_batch:
   sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
           p_keep_conv: 0.8, p_keep_hidden: 0.5})
 
  test_indices = np.arange(len(teX)) # Get A Test Batch
  np.random.shuffle(test_indices)
  test_indices = test_indices[0:test_size]
 
  print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
       sess.run(predict_op, feed_dict={X: teX[test_indices],
               p_keep_conv: 1.0,
               p_keep_hidden: 1.0})))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中关于时间和日期函数的常用计算总结(time和datatime)
Mar 08 Python
python paramiko模块学习分享
Aug 23 Python
mac下给python3安装requests库和scrapy库的实例
Jun 13 Python
Python从使用线程到使用async/await的深入讲解
Sep 16 Python
pandas求两个表格不相交的集合方法
Dec 08 Python
python基于递归解决背包问题详解
Jul 03 Python
python调用并链接MATLAB脚本详解
Jul 05 Python
详解Python中的format格式化函数的使用方法
Nov 20 Python
Python解释器以及PyCharm的安装教程图文详解
Feb 26 Python
Django 自定义404 500等错误页面的实现
Mar 08 Python
Python代码执行时间测量模块timeit用法解析
Jul 01 Python
Python 实现3种回归模型(Linear Regression,Lasso,Ridge)的示例
Oct 15 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
You might like
模板引擎Smarty深入浅出介绍
2006/12/06 PHP
可拖动窗口,附带鼠标控制渐变透明,开启关闭功能
2006/06/26 Javascript
实现变速回到顶部的JavaScript代码
2011/05/09 Javascript
Javascript学习笔记之函数篇(五) : 构造函数
2014/11/23 Javascript
谷歌浏览器调试JavaScript小技巧
2014/12/29 Javascript
jQuery实现tab选项卡效果的方法
2015/07/08 Javascript
jQuery Validate初步体验(二)
2015/12/12 Javascript
window.onerror()的用法与实例分析
2016/01/27 Javascript
Bootstrap CSS布局之表单
2016/12/17 Javascript
Node.js使用gm拼装sprite图片
2017/07/04 Javascript
使用Webpack提高Vue.js应用的方式汇总(四种)
2017/07/10 Javascript
React-Native使用Mobx实现购物车功能
2017/09/14 Javascript
vue.js中使用echarts实现数据动态刷新功能
2019/04/16 Javascript
python中 chr unichr ord函数的实例详解
2017/08/06 Python
Python使用Pandas库实现MySQL数据库的读写
2019/07/06 Python
在Pytorch中使用样本权重(sample_weight)的正确方法
2019/08/17 Python
flask框架渲染Jinja模板与传入模板变量操作详解
2020/01/25 Python
python如何通过twisted搭建socket服务
2020/02/03 Python
python opencv实现图片缺陷检测(讲解直方图以及相关系数对比法)
2020/04/07 Python
Python尾递归优化实现代码及原理详解
2020/10/09 Python
Django用户认证系统如何实现自定义
2020/11/12 Python
pycharm配置QtDesigner的超详细方法
2021/01/25 Python
ProBikeKit澳大利亚:自行车套件,跑步和铁人三项装备
2016/11/30 全球购物
Gweniss格温妮丝女包官网:英国纯手工制造潮流包包品牌
2018/02/07 全球购物
JAVA中的关键字有什么特点
2014/03/07 面试题
学生会竞选自荐信
2013/10/12 职场文书
中专毕业生自荐信
2013/11/16 职场文书
食堂个人先进事迹
2014/01/22 职场文书
销售员个人求职的自我评价
2014/02/10 职场文书
护理学专业求职信
2014/06/29 职场文书
党课培训心得体会
2014/09/02 职场文书
关于颐和园的导游词
2015/01/30 职场文书
特此通知格式
2015/04/27 职场文书
2016保送生自荐信范文
2016/01/29 职场文书
PostgreSQL解析URL的方法
2021/08/02 PostgreSQL
TypeScript实用技巧 Nominal Typing名义类型详解
2022/09/23 Javascript