python tensorflow基于cnn实现手写数字识别


Posted in Python onJanuary 01, 2018

一份基于cnn的手写数字自识别的代码,供大家参考,具体内容如下

# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 加载数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# 以交互式方式启动session
# 如果不使用交互式session,则在启动session前必须
# 构建整个计算图,才能启动该计算图
sess = tf.InteractiveSession()

"""构建计算图"""
# 通过占位符来为输入图像和目标输出类别创建节点
# shape参数是可选的,有了它tensorflow可以自动捕获维度不一致导致的错误
x = tf.placeholder("float", shape=[None, 784]) # 原始输入
y_ = tf.placeholder("float", shape=[None, 10]) # 目标值

# 为了不在建立模型的时候反复做初始化操作,
# 我们定义两个函数用于初始化
def weight_variable(shape):
 # 截尾正态分布,stddev是正态分布的标准偏差
 initial = tf.truncated_normal(shape=shape, stddev=0.1)
 return tf.Variable(initial)
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)

# 卷积核池化,步长为1,0边距
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
       strides=[1, 2, 2, 1], padding='SAME')

"""第一层卷积"""
# 由一个卷积和一个最大池化组成。滤波器5x5中算出32个特征,是因为使用32个滤波器进行卷积
# 卷积的权重张量形状是[5, 5, 1, 32],1是输入通道的个数,32是输出通道个数
W_conv1 = weight_variable([5, 5, 1, 32])
# 每一个输出通道都有一个偏置量
b_conv1 = bias_variable([32])

# 位了使用卷积,必须将输入转换成4维向量,2、3维表示图片的宽、高
# 最后一维表示图片的颜色通道(因为是灰度图像所以通道数维1,RGB图像通道数为3)
x_image = tf.reshape(x, [-1, 28, 28, 1])

# 第一层的卷积结果,使用Relu作为激活函数
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1))
# 第一层卷积后的池化结果
h_pool1 = max_pool_2x2(h_conv1)

"""第二层卷积"""
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

"""全连接层"""
# 图片尺寸减小到7*7,加入一个有1024个神经元的全连接层
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
# 将最后的池化层输出张量reshape成一维向量
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
# 全连接层的输出
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

"""使用Dropout减少过拟合"""
# 使用placeholder占位符来表示神经元的输出在dropout中保持不变的概率
# 在训练的过程中启用dropout,在测试过程中关闭dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

"""输出层"""
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
# 模型预测输出
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

# 交叉熵损失
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))

# 模型训练,使用AdamOptimizer来做梯度最速下降
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# 正确预测,得到True或False的List
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1))
# 将布尔值转化成浮点数,取平均值作为精确度
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# 在session中先初始化变量才能在session中调用
sess.run(tf.global_variables_initializer())

# 迭代优化模型
for i in range(2000):
 # 每次取50个样本进行训练
 batch = mnist.train.next_batch(50)
 if i%100 == 0:
  train_accuracy = accuracy.eval(feed_dict={
   x: batch[0], y_: batch[1], keep_prob: 1.0}) # 模型中间不使用dropout
  print("step %d, training accuracy %g" % (i, train_accuracy))
 train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob: 0.5})
print("test accuracy %g" % accuracy.eval(feed_dict={
   x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

做了2000次迭代,在测试集上的识别精度能够到0.9772……

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python命令行参数解析模块optparse使用实例
Apr 13 Python
使用beaker让Facebook的Bottle框架支持session功能
Apr 23 Python
python 远程统计文件代码分享
May 14 Python
Python数据分析:手把手教你用Pandas生成可视化图表的教程
Dec 15 Python
python读取几个G的csv文件方法
Jan 07 Python
python random从集合中随机选择元素的方法
Jan 23 Python
Python函数定义及传参方式详解(4种)
Mar 18 Python
python 实现图像快速替换某种颜色
Jun 04 Python
Python中bisect的用法及示例详解
Jul 20 Python
如何利用pycharm进行代码更新比较
Nov 04 Python
python语言实现贪吃蛇游戏
Nov 13 Python
Python 全局空间和局部空间
Apr 06 Python
python+selenium实现163邮箱自动登陆的方法
Dec 31 #Python
python 类对象和实例对象动态添加方法(分享)
Dec 31 #Python
利用python将图片转换成excel文档格式
Dec 30 #Python
书单|人生苦短,你还不用python!
Dec 29 #Python
python ansible服务及剧本编写
Dec 29 #Python
详解python 拆包可迭代数据如tuple, list
Dec 29 #Python
详解Python异常处理中的Finally else的功能
Dec 29 #Python
You might like
prototype 1.5 & scriptaculous 1.6.1 学习笔记
2006/09/07 Javascript
用js实现的一个Flash滚动轮换显示图片代码生成器
2007/03/14 Javascript
JavaScript中Math对象使用说明
2008/01/16 Javascript
js 实现日期灵活格式化的小例子
2013/07/14 Javascript
理解javascript回调函数
2014/12/28 Javascript
Javascript基于对象三大特性(封装性、继承性、多态性)
2016/01/04 Javascript
基于BootStrap Metronic开发框架经验小结【五】Bootstrap File Input文件上传插件的用法详解
2016/05/12 Javascript
json对象转为字符串,当做参数传递时加密解密的实现方法
2016/06/29 Javascript
JavaScript排序算法动画演示效果的实现方法
2016/10/18 Javascript
JavaScript输入框字数实时统计更新
2017/06/17 Javascript
JavaScript实现短信倒计时60s
2017/10/09 Javascript
vue和better-scroll实现列表左右联动效果详解
2019/04/29 Javascript
详解async/await 异步应用的常用场景
2019/05/13 Javascript
vue+layui实现select动态加载后台数据的例子
2019/09/20 Javascript
[44:30]完美世界DOTA2联赛PWL S2 GXR vs Magma 第一场 11.25
2020/11/26 DOTA
详谈python http长连接客户端
2017/06/12 Python
Python爬取数据并写入MySQL数据库的实例
2018/06/21 Python
python 图像平移和旋转的实例
2019/01/10 Python
python调用并链接MATLAB脚本详解
2019/07/05 Python
基于python实现生成指定大小txt文档
2020/07/20 Python
Python 爬虫的原理
2020/07/30 Python
CSS3弹性盒模型开发笔记(一)
2016/04/26 HTML / CSS
美国女士泳装店:Swimsuits For All
2017/03/02 全球购物
Java中会存在内存泄漏吗,请简单描述
2016/12/22 面试题
servlet面试题
2012/08/20 面试题
中专生求职自荐信范文
2013/12/22 职场文书
2014年应届大学生自我评价
2014/01/09 职场文书
标准的毕业生自荐信
2014/04/20 职场文书
英文推荐信格式范文
2014/05/09 职场文书
优秀家长事迹材料
2014/05/17 职场文书
2014年乡镇妇联工作总结
2014/12/02 职场文书
新闻通讯稿范文
2015/07/22 职场文书
劳动模范获奖感言
2015/07/31 职场文书
唱歌比赛拉拉队口号
2015/12/25 职场文书
查看nginx配置文件路径和资源文件路径的方法
2021/03/31 Servers
Java9新特性之Module模块化编程示例演绎
2022/03/16 Java/Android