python tensorflow基于cnn实现手写数字识别


Posted in Python onJanuary 01, 2018

一份基于cnn的手写数字自识别的代码,供大家参考,具体内容如下

# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 加载数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# 以交互式方式启动session
# 如果不使用交互式session,则在启动session前必须
# 构建整个计算图,才能启动该计算图
sess = tf.InteractiveSession()

"""构建计算图"""
# 通过占位符来为输入图像和目标输出类别创建节点
# shape参数是可选的,有了它tensorflow可以自动捕获维度不一致导致的错误
x = tf.placeholder("float", shape=[None, 784]) # 原始输入
y_ = tf.placeholder("float", shape=[None, 10]) # 目标值

# 为了不在建立模型的时候反复做初始化操作,
# 我们定义两个函数用于初始化
def weight_variable(shape):
 # 截尾正态分布,stddev是正态分布的标准偏差
 initial = tf.truncated_normal(shape=shape, stddev=0.1)
 return tf.Variable(initial)
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)

# 卷积核池化,步长为1,0边距
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
       strides=[1, 2, 2, 1], padding='SAME')

"""第一层卷积"""
# 由一个卷积和一个最大池化组成。滤波器5x5中算出32个特征,是因为使用32个滤波器进行卷积
# 卷积的权重张量形状是[5, 5, 1, 32],1是输入通道的个数,32是输出通道个数
W_conv1 = weight_variable([5, 5, 1, 32])
# 每一个输出通道都有一个偏置量
b_conv1 = bias_variable([32])

# 位了使用卷积,必须将输入转换成4维向量,2、3维表示图片的宽、高
# 最后一维表示图片的颜色通道(因为是灰度图像所以通道数维1,RGB图像通道数为3)
x_image = tf.reshape(x, [-1, 28, 28, 1])

# 第一层的卷积结果,使用Relu作为激活函数
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1))
# 第一层卷积后的池化结果
h_pool1 = max_pool_2x2(h_conv1)

"""第二层卷积"""
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

"""全连接层"""
# 图片尺寸减小到7*7,加入一个有1024个神经元的全连接层
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
# 将最后的池化层输出张量reshape成一维向量
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
# 全连接层的输出
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

"""使用Dropout减少过拟合"""
# 使用placeholder占位符来表示神经元的输出在dropout中保持不变的概率
# 在训练的过程中启用dropout,在测试过程中关闭dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

"""输出层"""
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
# 模型预测输出
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

# 交叉熵损失
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))

# 模型训练,使用AdamOptimizer来做梯度最速下降
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# 正确预测,得到True或False的List
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1))
# 将布尔值转化成浮点数,取平均值作为精确度
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# 在session中先初始化变量才能在session中调用
sess.run(tf.global_variables_initializer())

# 迭代优化模型
for i in range(2000):
 # 每次取50个样本进行训练
 batch = mnist.train.next_batch(50)
 if i%100 == 0:
  train_accuracy = accuracy.eval(feed_dict={
   x: batch[0], y_: batch[1], keep_prob: 1.0}) # 模型中间不使用dropout
  print("step %d, training accuracy %g" % (i, train_accuracy))
 train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob: 0.5})
print("test accuracy %g" % accuracy.eval(feed_dict={
   x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

做了2000次迭代,在测试集上的识别精度能够到0.9772……

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 生成目录树及显示文件大小的代码
Jul 23 Python
Python中实现字符串类型与字典类型相互转换的方法
Aug 18 Python
跟老齐学Python之大话题小函数(1)
Oct 10 Python
浅谈Python 字符串格式化输出(format/printf)
Jul 21 Python
Python基于动态规划算法解决01背包问题实例
Dec 06 Python
tensorflow学习教程之文本分类详析
Aug 07 Python
Python实现数据结构线性链表(单链表)算法示例
May 04 Python
django Admin文档生成器使用详解
Jul 22 Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 Python
Python如何实现强制数据类型转换
Nov 22 Python
Python 元组拆包示例(Tuple Unpacking)
Dec 24 Python
Python如何进行时间处理
Aug 06 Python
python+selenium实现163邮箱自动登陆的方法
Dec 31 #Python
python 类对象和实例对象动态添加方法(分享)
Dec 31 #Python
利用python将图片转换成excel文档格式
Dec 30 #Python
书单|人生苦短,你还不用python!
Dec 29 #Python
python ansible服务及剧本编写
Dec 29 #Python
详解python 拆包可迭代数据如tuple, list
Dec 29 #Python
详解Python异常处理中的Finally else的功能
Dec 29 #Python
You might like
可快速识别放射性物质-国外大神教你diy一个开放式辐射探测器
2020/03/12 无线电
PHP5中虚函数的实现方法分享
2011/04/20 PHP
PHP目录与文件操作技巧总结(创建,删除,遍历,读写,修改等)
2016/09/11 PHP
Windows下php+mysql5.7配置教程
2017/05/16 PHP
js+css 实现遮罩居中弹出层(随浏览器窗口滚动条滚动)
2013/12/11 Javascript
js与运算符和或运算符的妙用
2014/02/14 Javascript
Js实现手机发送验证码时按钮延迟操作
2014/06/20 Javascript
jQuery实现表单提交时判断的方法
2014/12/13 Javascript
JavaScript获取一个范围内日期的方法
2015/04/24 Javascript
javascript获取网页各种高宽及位置的方法总结
2016/07/27 Javascript
JavaScript输出所选择起始与结束日期的方法
2017/07/12 Javascript
angular中ui calendar的一些使用心得(推荐)
2017/11/03 Javascript
Vue的elementUI实现自定义主题方法
2018/02/23 Javascript
vue之封装多个组件调用同一接口的案例
2020/08/11 Javascript
vue在App.vue文件中监听路由变化刷新页面操作
2020/08/14 Javascript
Vue 使用iframe引用html页面实现vue和html页面方法的调用操作
2020/11/16 Javascript
springboot+vue实现文件上传下载
2020/11/17 Vue.js
[02:03]风行者至宝清风环佩外观展示
2020/09/05 DOTA
python登陆asp网站页面的实现代码
2015/01/14 Python
pymssql数据库操作MSSQL2005实例分析
2015/05/25 Python
Python数据持久化shelve模块用法分析
2018/06/29 Python
python设计微型小说网站(基于Django+Bootstrap框架)
2019/07/08 Python
Python使用tkinter模块实现推箱子游戏
2019/10/08 Python
Django 自定义权限管理系统详解(通过中间件认证)
2020/03/11 Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
2020/06/29 Python
Python数据模型与Python对象模型的相关总结
2021/01/26 Python
移动端Html5中百度地图的点击事件
2019/01/31 HTML / CSS
AmazeUI中模态框的实现
2020/08/19 HTML / CSS
健康监测猫砂:Pretty Litter
2017/05/25 全球购物
入党积极分子思想汇报范文
2014/01/05 职场文书
银行优秀员工事迹材料
2014/05/29 职场文书
小学语文教研活动总结
2014/07/01 职场文书
2014年项目工作总结
2014/11/24 职场文书
神龙架导游词
2015/02/11 职场文书
指导教师推荐意见
2015/06/05 职场文书
springboot实现string转json json里面带数组
2022/06/16 Java/Android