python tensorflow基于cnn实现手写数字识别


Posted in Python onJanuary 01, 2018

一份基于cnn的手写数字自识别的代码,供大家参考,具体内容如下

# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 加载数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# 以交互式方式启动session
# 如果不使用交互式session,则在启动session前必须
# 构建整个计算图,才能启动该计算图
sess = tf.InteractiveSession()

"""构建计算图"""
# 通过占位符来为输入图像和目标输出类别创建节点
# shape参数是可选的,有了它tensorflow可以自动捕获维度不一致导致的错误
x = tf.placeholder("float", shape=[None, 784]) # 原始输入
y_ = tf.placeholder("float", shape=[None, 10]) # 目标值

# 为了不在建立模型的时候反复做初始化操作,
# 我们定义两个函数用于初始化
def weight_variable(shape):
 # 截尾正态分布,stddev是正态分布的标准偏差
 initial = tf.truncated_normal(shape=shape, stddev=0.1)
 return tf.Variable(initial)
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)

# 卷积核池化,步长为1,0边距
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
       strides=[1, 2, 2, 1], padding='SAME')

"""第一层卷积"""
# 由一个卷积和一个最大池化组成。滤波器5x5中算出32个特征,是因为使用32个滤波器进行卷积
# 卷积的权重张量形状是[5, 5, 1, 32],1是输入通道的个数,32是输出通道个数
W_conv1 = weight_variable([5, 5, 1, 32])
# 每一个输出通道都有一个偏置量
b_conv1 = bias_variable([32])

# 位了使用卷积,必须将输入转换成4维向量,2、3维表示图片的宽、高
# 最后一维表示图片的颜色通道(因为是灰度图像所以通道数维1,RGB图像通道数为3)
x_image = tf.reshape(x, [-1, 28, 28, 1])

# 第一层的卷积结果,使用Relu作为激活函数
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1))
# 第一层卷积后的池化结果
h_pool1 = max_pool_2x2(h_conv1)

"""第二层卷积"""
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

"""全连接层"""
# 图片尺寸减小到7*7,加入一个有1024个神经元的全连接层
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
# 将最后的池化层输出张量reshape成一维向量
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
# 全连接层的输出
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

"""使用Dropout减少过拟合"""
# 使用placeholder占位符来表示神经元的输出在dropout中保持不变的概率
# 在训练的过程中启用dropout,在测试过程中关闭dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

"""输出层"""
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
# 模型预测输出
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

# 交叉熵损失
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))

# 模型训练,使用AdamOptimizer来做梯度最速下降
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# 正确预测,得到True或False的List
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1))
# 将布尔值转化成浮点数,取平均值作为精确度
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# 在session中先初始化变量才能在session中调用
sess.run(tf.global_variables_initializer())

# 迭代优化模型
for i in range(2000):
 # 每次取50个样本进行训练
 batch = mnist.train.next_batch(50)
 if i%100 == 0:
  train_accuracy = accuracy.eval(feed_dict={
   x: batch[0], y_: batch[1], keep_prob: 1.0}) # 模型中间不使用dropout
  print("step %d, training accuracy %g" % (i, train_accuracy))
 train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob: 0.5})
print("test accuracy %g" % accuracy.eval(feed_dict={
   x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

做了2000次迭代,在测试集上的识别精度能够到0.9772……

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python根据路径导入模块的方法
Sep 30 Python
python实现中文输出的两种方法
May 09 Python
Python提取网页中超链接的方法
Sep 18 Python
python接口自动化测试之接口数据依赖的实现方法
Apr 26 Python
Django中如何防范CSRF跨站点请求伪造攻击的实现
Apr 28 Python
PyQt5基本控件使用详解:单选按钮、复选框、下拉框
Aug 05 Python
python 协程 gevent原理与用法分析
Nov 22 Python
python dict乱码如何解决
Jun 07 Python
Python3爬虫关于识别点触点选验证码的实例讲解
Jul 30 Python
浅析pandas随机排列与随机抽样
Jan 22 Python
Python NumPy灰度图像的压缩原理讲解
Aug 04 Python
python实现学生信息管理系统(面向对象)
Jun 05 Python
python+selenium实现163邮箱自动登陆的方法
Dec 31 #Python
python 类对象和实例对象动态添加方法(分享)
Dec 31 #Python
利用python将图片转换成excel文档格式
Dec 30 #Python
书单|人生苦短,你还不用python!
Dec 29 #Python
python ansible服务及剧本编写
Dec 29 #Python
详解python 拆包可迭代数据如tuple, list
Dec 29 #Python
详解Python异常处理中的Finally else的功能
Dec 29 #Python
You might like
使用无限生命期Session的方法
2006/10/09 PHP
PHP中使用php5-ffmpeg撷取视频图片实例
2015/01/07 PHP
php操作xml入门之xml基本介绍及xml标签元素
2015/01/23 PHP
PHP 实现人民币小写转换成大写的方法及大小写转换函数
2017/11/17 PHP
jQuery Ajax之$.get()方法和$.post()方法
2009/10/12 Javascript
百度留言本js 大家可以参考下
2009/10/13 Javascript
JS/jQuery实现默认显示部分文字点击按钮显示全部内容
2013/05/13 Javascript
jquery实现动态画圆
2014/12/04 Javascript
JavaScript实现函数返回多个值的方法
2015/06/09 Javascript
jQuery表格插件datatables用法详解
2020/11/23 Javascript
javascript检测flash插件是否被禁用的方法
2016/01/14 Javascript
javascript中FOREACH数组方法使用示例
2016/03/01 Javascript
创建一个类Person的简单实例
2016/05/17 Javascript
前端弹出对话框 js实现ajax交互
2016/09/09 Javascript
js浏览器滚动条卷去的高度scrolltop(实例讲解)
2017/07/07 Javascript
Node.js学习之查询字符串解析querystring详解
2017/09/28 Javascript
快速解决layui弹窗按enter键不停弹窗的问题
2019/09/18 Javascript
json解析大全 双引号、键值对不在一起的情况
2019/12/06 Javascript
全面解析JavaScript Module模式
2020/07/24 Javascript
vue实现轮播图帧率播放
2021/01/26 Vue.js
[45:06]完美世界DOTA2联赛PWL S2 Magma vs InkIce 第二场 11.28
2020/12/02 DOTA
使用Pyrex来扩展和加速Python程序的教程
2015/04/13 Python
和孩子一起学习python之变量命名规则
2018/05/27 Python
在Python中字典根据多项规则排序的方法
2019/01/21 Python
python super用法及原理详解
2020/01/20 Python
django迁移文件migrations的实现
2020/03/31 Python
Python字符串格式化常用手段及注意事项
2020/06/17 Python
Tensorflow全局设置可见GPU编号操作
2020/06/30 Python
python中scipy.stats产生随机数实例讲解
2021/02/19 Python
HTML5的结构和语义(5):内嵌媒体
2008/10/17 HTML / CSS
美国电力供应商店/电气批发商:USESI
2018/10/12 全球购物
英国赛车、汽车改装和摩托车零件购物网站:Demon Tweeks
2018/10/29 全球购物
华为C++笔试题
2014/08/05 面试题
中秋节晚会开场白
2015/05/29 职场文书
大学开学感言
2015/08/01 职场文书
Python的三个重要函数详解
2022/01/18 Python