python tensorflow基于cnn实现手写数字识别


Posted in Python onJanuary 01, 2018

一份基于cnn的手写数字自识别的代码,供大家参考,具体内容如下

# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 加载数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# 以交互式方式启动session
# 如果不使用交互式session,则在启动session前必须
# 构建整个计算图,才能启动该计算图
sess = tf.InteractiveSession()

"""构建计算图"""
# 通过占位符来为输入图像和目标输出类别创建节点
# shape参数是可选的,有了它tensorflow可以自动捕获维度不一致导致的错误
x = tf.placeholder("float", shape=[None, 784]) # 原始输入
y_ = tf.placeholder("float", shape=[None, 10]) # 目标值

# 为了不在建立模型的时候反复做初始化操作,
# 我们定义两个函数用于初始化
def weight_variable(shape):
 # 截尾正态分布,stddev是正态分布的标准偏差
 initial = tf.truncated_normal(shape=shape, stddev=0.1)
 return tf.Variable(initial)
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)

# 卷积核池化,步长为1,0边距
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
       strides=[1, 2, 2, 1], padding='SAME')

"""第一层卷积"""
# 由一个卷积和一个最大池化组成。滤波器5x5中算出32个特征,是因为使用32个滤波器进行卷积
# 卷积的权重张量形状是[5, 5, 1, 32],1是输入通道的个数,32是输出通道个数
W_conv1 = weight_variable([5, 5, 1, 32])
# 每一个输出通道都有一个偏置量
b_conv1 = bias_variable([32])

# 位了使用卷积,必须将输入转换成4维向量,2、3维表示图片的宽、高
# 最后一维表示图片的颜色通道(因为是灰度图像所以通道数维1,RGB图像通道数为3)
x_image = tf.reshape(x, [-1, 28, 28, 1])

# 第一层的卷积结果,使用Relu作为激活函数
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1))
# 第一层卷积后的池化结果
h_pool1 = max_pool_2x2(h_conv1)

"""第二层卷积"""
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

"""全连接层"""
# 图片尺寸减小到7*7,加入一个有1024个神经元的全连接层
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
# 将最后的池化层输出张量reshape成一维向量
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
# 全连接层的输出
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

"""使用Dropout减少过拟合"""
# 使用placeholder占位符来表示神经元的输出在dropout中保持不变的概率
# 在训练的过程中启用dropout,在测试过程中关闭dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

"""输出层"""
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
# 模型预测输出
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

# 交叉熵损失
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))

# 模型训练,使用AdamOptimizer来做梯度最速下降
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# 正确预测,得到True或False的List
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1))
# 将布尔值转化成浮点数,取平均值作为精确度
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# 在session中先初始化变量才能在session中调用
sess.run(tf.global_variables_initializer())

# 迭代优化模型
for i in range(2000):
 # 每次取50个样本进行训练
 batch = mnist.train.next_batch(50)
 if i%100 == 0:
  train_accuracy = accuracy.eval(feed_dict={
   x: batch[0], y_: batch[1], keep_prob: 1.0}) # 模型中间不使用dropout
  print("step %d, training accuracy %g" % (i, train_accuracy))
 train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob: 0.5})
print("test accuracy %g" % accuracy.eval(feed_dict={
   x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

做了2000次迭代,在测试集上的识别精度能够到0.9772……

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python模拟登录12306的方法
Dec 30 Python
python实现挑选出来100以内的质数
Mar 24 Python
浅析Python中signal包的使用
Nov 13 Python
Python学习小技巧之利用字典的默认行为
May 20 Python
python实现稀疏矩阵示例代码
Jun 09 Python
13个最常用的Python深度学习库介绍
Oct 28 Python
python MysqlDb模块安装及其使用详解
Feb 23 Python
Python查找第n个子串的技巧分享
Jun 27 Python
Python动态语言与鸭子类型详解
Jul 01 Python
python按行读取文件并找出其中指定字符串
Aug 08 Python
详解Python图像处理库Pillow常用使用方法
Sep 02 Python
解决Python字典查找报Keyerror的问题
May 26 Python
python+selenium实现163邮箱自动登陆的方法
Dec 31 #Python
python 类对象和实例对象动态添加方法(分享)
Dec 31 #Python
利用python将图片转换成excel文档格式
Dec 30 #Python
书单|人生苦短,你还不用python!
Dec 29 #Python
python ansible服务及剧本编写
Dec 29 #Python
详解python 拆包可迭代数据如tuple, list
Dec 29 #Python
详解Python异常处理中的Finally else的功能
Dec 29 #Python
You might like
php 时间计算问题小结
2009/01/04 PHP
PHP实现对站点内容外部链接的过滤方法
2014/09/10 PHP
Prototype Function对象 学习
2009/07/12 Javascript
在Javascript中 声明时用"var"与不用"var"的区别
2013/04/15 Javascript
在JavaScript里嵌入大量字符串常量的实现方法
2013/07/07 Javascript
js substring从右边获取指定长度字符串(示例代码)
2013/12/23 Javascript
js判断为空Null与字符串为空简写方法
2014/02/24 Javascript
使用javascript实现Iframe自适应高度
2014/12/24 Javascript
Json对象和字符串互相转换json数据拼接和JSON使用方式详细介绍(小结)
2016/10/25 Javascript
JavaScript的六种继承方式(推荐)
2017/06/26 Javascript
微信小程序实现长按删除图片的示例
2018/05/18 Javascript
最后说说Vue2 SSR 的 Cookies 问题
2018/05/25 Javascript
vue-router beforeEach跳转路由验证用户登录状态
2018/12/26 Javascript
js前端如何写一个精确的倒计时代码
2019/10/25 Javascript
[01:59]翻天覆地,因你而变,7.20版本地图更新速览
2018/11/24 DOTA
[54:18]DOTA2-DPC中国联赛 正赛 PSG.LGD vs LBZS BO3 第一场 1月22日
2021/03/11 DOTA
用Python的Tornado框架结合memcached页面改善博客性能
2015/04/24 Python
Python脚本实时处理log文件的方法
2016/11/21 Python
基于Python 装饰器装饰类中的方法实例
2018/04/21 Python
pygame游戏之旅 添加游戏暂停功能
2018/11/21 Python
对Python3 goto 语句的使用方法详解
2019/02/16 Python
python opencv将图片转为灰度图的方法示例
2019/07/31 Python
Python 一键获取百度网盘提取码的方法
2019/08/01 Python
python中使用np.delete()的实例方法
2021/02/01 Python
详解python3 GUI刷屏器(附源码)
2021/02/18 Python
Hotels.com中国区:好订网
2016/08/18 全球购物
英国打印机墨水和碳粉商店:Printerinks
2017/06/30 全球购物
正宗的日本零食和糖果订阅盒:Bokksu
2019/11/21 全球购物
What's the difference between deep copy and shallow copy? (深拷贝与浅拷贝有什么区别)
2015/11/10 面试题
建筑专业自荐信范文
2014/01/05 职场文书
预备党员党课思想汇报
2014/01/13 职场文书
倡议书格式范文
2014/04/14 职场文书
中学生的1000字检讨书
2014/10/11 职场文书
2015年度质量工作总结报告
2015/04/27 职场文书
2016秋季运动会前导词
2015/11/25 职场文书
Python基础知识学习之类的继承
2021/05/31 Python