tensorflow使用CNN分析mnist手写体数字数据集


Posted in Python onJune 17, 2020

本文实例为大家分享了tensorflow使用CNN分析mnist手写体数字数据集,供大家参考,具体内容如下

import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.examples.tutorials.mnist import input_data
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#把上述trX和teX的形状变为[-1,28,28,1],-1表示不考虑输入图片的数量,28×28是图片的长和宽的像素数,
# 1是通道(channel)数量,因为MNIST的图片是黑白的,所以通道是1,如果是RGB彩色图像,通道是3。
trX = trX.reshape(-1, 28, 28, 1) # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1) # 28x28x1 input img
 
X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])
#初始化权重与定义网络结构。
# 这里,我们将要构建一个拥有3个卷积层和3个池化层,随后接1个全连接层和1个输出层的卷积神经网络
def init_weights(shape):
 return tf.Variable(tf.random_normal(shape, stddev=0.01))
 
w = init_weights([3, 3, 1, 32])   # patch大小为3×3,输入维度为1,输出维度为32
w2 = init_weights([3, 3, 32, 64])   # patch大小为3×3,输入维度为32,输出维度为64
w3 = init_weights([3, 3, 64, 128])   # patch大小为3×3,输入维度为64,输出维度为128
w4 = init_weights([128 * 4 * 4, 625])  # 全连接层,输入维度为 128 × 4 × 4,是上一层的输出数据又三维的转变成一维, 输出维度为625
w_o = init_weights([625, 10]) # 输出层,输入维度为 625, 输出维度为10,代表10类(labels)
# 神经网络模型的构建函数,传入以下参数
# X:输入数据
# w:每一层的权重
# p_keep_conv,p_keep_hidden:dropout要保留的神经元比例
 
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
 # 第一组卷积层及池化层,最后dropout一些神经元
 l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
 # l1a shape=(?, 28, 28, 32)
 l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l1 shape=(?, 14, 14, 32)
 l1 = tf.nn.dropout(l1, p_keep_conv)
 
 # 第二组卷积层及池化层,最后dropout一些神经元
 l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
 # l2a shape=(?, 14, 14, 64)
 l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l2 shape=(?, 7, 7, 64)
 l2 = tf.nn.dropout(l2, p_keep_conv)
 # 第三组卷积层及池化层,最后dropout一些神经元
 l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
 # l3a shape=(?, 7, 7, 128)
 l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l3 shape=(?, 4, 4, 128)
 l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]]) # reshape to (?, 2048)
 l3 = tf.nn.dropout(l3, p_keep_conv)
 # 全连接层,最后dropout一些神经元
 l4 = tf.nn.relu(tf.matmul(l3, w4))
 l4 = tf.nn.dropout(l4, p_keep_hidden)
 # 输出层
 pyx = tf.matmul(l4, w_o)
 return pyx #返回预测值
 
#我们定义dropout的占位符——keep_conv,它表示在一层中有多少比例的神经元被保留下来。生成网络模型,得到预测值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) #得到预测值
#定义损失函数,这里我们仍然采用tf.nn.softmax_cross_entropy_with_logits来比较预测值和真实值的差异,并做均值处理;
# 定义训练的操作(train_op),采用实现RMSProp算法的优化器tf.train.RMSPropOptimizer,学习率为0.001,衰减值为0.9,使损失最小;
# 定义预测的操作(predict_op)
cost = tf.reduce_mean(tf.nn. softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)
#定义训练时的批次大小和评估时的批次大小
batch_size = 128
test_size = 256
#在一个会话中启动图,开始训练和评估
# Launch the graph in a session
with tf.Session() as sess:
 # you need to initialize all variables
 tf. global_variables_initializer().run()
 for i in range(100):
  training_batch = zip(range(0, len(trX), batch_size),
        range(batch_size, len(trX)+1, batch_size))
  for start, end in training_batch:
   sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
           p_keep_conv: 0.8, p_keep_hidden: 0.5})
 
  test_indices = np.arange(len(teX)) # Get A Test Batch
  np.random.shuffle(test_indices)
  test_indices = test_indices[0:test_size]
 
  print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
       sess.run(predict_op, feed_dict={X: teX[test_indices],
               p_keep_conv: 1.0,
               p_keep_hidden: 1.0})))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
通过python+selenium3实现浏览器刷简书文章阅读量
Dec 26 Python
Python爬虫天气预报实例详解(小白入门)
Jan 24 Python
python操作oracle的完整教程分享
Jan 30 Python
python学生信息管理系统
Mar 13 Python
通过实例了解python property属性
Nov 01 Python
python GUI库图形界面开发之PyQt5多线程中信号与槽的详细使用方法与实例
Mar 08 Python
python实现ftp文件传输功能
Mar 20 Python
TensorFlow2.1.0最新版本安装详细教程
Apr 08 Python
Pycharm pyuic5实现将ui文件转为py文件,让UI界面成功显示
Apr 08 Python
python读取yaml文件后修改写入本地实例
Apr 27 Python
Python中常见的数制转换有哪些
May 27 Python
通用的Django注册功能模块实现方法
Feb 05 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
You might like
PHP中return 和 exit 、break和contiue 区别与用法
2012/04/09 PHP
php设计模式小结
2013/02/15 PHP
10 个经典PHP函数
2013/10/17 PHP
PHP中上传多个文件的表单设计例子
2014/11/19 PHP
基于linnux+phantomjs实现生成图片格式的网页快照
2015/04/15 PHP
JS实多级联动下拉菜单类,简单实现省市区联动菜单!
2007/05/03 Javascript
JQuery 学习笔记 选择器之一
2009/07/23 Javascript
基于JavaScript 下namespace 功能的简单分析
2013/07/05 Javascript
给html超链接设置事件不使用href来完成跳
2014/04/20 Javascript
js判断上传文件后缀名是否合法
2016/01/28 Javascript
javascript Promise简单学习使用方法小结
2016/05/17 Javascript
JavaScript中子对象访问父对象的方式详解
2016/09/01 Javascript
基于iscroll.js实现下拉刷新和上拉加载效果
2016/11/28 Javascript
js事件冒泡与事件捕获详解
2017/02/20 Javascript
js上传图片预览的实现方法
2017/05/09 Javascript
vuejs事件中心管理组件间的通信详解
2017/08/09 Javascript
详解Vue单元测试Karma+Mocha学习笔记
2018/01/31 Javascript
jquery ajaxfileuplod 上传文件 essyui laoding 效果【防止重复上传文件】
2018/05/26 jQuery
Angular2之二级路由详解
2018/08/31 Javascript
node.js的Express服务器基本使用教程
2019/01/09 Javascript
使用Vue中 v-for循环列表控制按钮隐藏显示功能
2019/04/23 Javascript
基于jQuery的时间戳与日期间的转化
2019/06/21 jQuery
微信小程序登录时如何获取input框中的内容
2019/12/04 Javascript
vue 页面回退mounted函数不执行的解决方案
2020/07/26 Javascript
Python读取YUV文件,并显示的方法
2018/12/04 Python
python实现扫描ip地址的小程序
2019/04/16 Python
Python 依赖库太多了该如何管理
2019/11/08 Python
css3实现画半圆弧线的示例代码
2017/11/06 HTML / CSS
基于IE10/HTML5 开发
2013/04/22 HTML / CSS
物理学专业求职信
2014/07/04 职场文书
卖车协议书范例
2014/09/16 职场文书
科长个人四风问题整改措施思想汇报
2014/10/13 职场文书
乡镇科协工作总结2015
2015/05/19 职场文书
2015年酒店年度工作总结
2015/05/23 职场文书
javascript之Object.assign()的痛点分析
2022/03/03 Javascript
JavaWeb实现显示mysql数据库数据
2022/03/19 Java/Android