python tensorflow基于cnn实现手写数字识别


Posted in Python onJanuary 01, 2018

一份基于cnn的手写数字自识别的代码,供大家参考,具体内容如下

# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 加载数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# 以交互式方式启动session
# 如果不使用交互式session,则在启动session前必须
# 构建整个计算图,才能启动该计算图
sess = tf.InteractiveSession()

"""构建计算图"""
# 通过占位符来为输入图像和目标输出类别创建节点
# shape参数是可选的,有了它tensorflow可以自动捕获维度不一致导致的错误
x = tf.placeholder("float", shape=[None, 784]) # 原始输入
y_ = tf.placeholder("float", shape=[None, 10]) # 目标值

# 为了不在建立模型的时候反复做初始化操作,
# 我们定义两个函数用于初始化
def weight_variable(shape):
 # 截尾正态分布,stddev是正态分布的标准偏差
 initial = tf.truncated_normal(shape=shape, stddev=0.1)
 return tf.Variable(initial)
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)

# 卷积核池化,步长为1,0边距
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
       strides=[1, 2, 2, 1], padding='SAME')

"""第一层卷积"""
# 由一个卷积和一个最大池化组成。滤波器5x5中算出32个特征,是因为使用32个滤波器进行卷积
# 卷积的权重张量形状是[5, 5, 1, 32],1是输入通道的个数,32是输出通道个数
W_conv1 = weight_variable([5, 5, 1, 32])
# 每一个输出通道都有一个偏置量
b_conv1 = bias_variable([32])

# 位了使用卷积,必须将输入转换成4维向量,2、3维表示图片的宽、高
# 最后一维表示图片的颜色通道(因为是灰度图像所以通道数维1,RGB图像通道数为3)
x_image = tf.reshape(x, [-1, 28, 28, 1])

# 第一层的卷积结果,使用Relu作为激活函数
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1))
# 第一层卷积后的池化结果
h_pool1 = max_pool_2x2(h_conv1)

"""第二层卷积"""
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

"""全连接层"""
# 图片尺寸减小到7*7,加入一个有1024个神经元的全连接层
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
# 将最后的池化层输出张量reshape成一维向量
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
# 全连接层的输出
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

"""使用Dropout减少过拟合"""
# 使用placeholder占位符来表示神经元的输出在dropout中保持不变的概率
# 在训练的过程中启用dropout,在测试过程中关闭dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

"""输出层"""
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
# 模型预测输出
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

# 交叉熵损失
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))

# 模型训练,使用AdamOptimizer来做梯度最速下降
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# 正确预测,得到True或False的List
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1))
# 将布尔值转化成浮点数,取平均值作为精确度
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# 在session中先初始化变量才能在session中调用
sess.run(tf.global_variables_initializer())

# 迭代优化模型
for i in range(2000):
 # 每次取50个样本进行训练
 batch = mnist.train.next_batch(50)
 if i%100 == 0:
  train_accuracy = accuracy.eval(feed_dict={
   x: batch[0], y_: batch[1], keep_prob: 1.0}) # 模型中间不使用dropout
  print("step %d, training accuracy %g" % (i, train_accuracy))
 train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob: 0.5})
print("test accuracy %g" % accuracy.eval(feed_dict={
   x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

做了2000次迭代,在测试集上的识别精度能够到0.9772……

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
linux 下实现python多版本安装实践
Nov 18 Python
Python中的日期时间处理详解
Nov 17 Python
Python简单计算数组元素平均值的方法示例
Dec 26 Python
基于python中theano库的线性回归
Aug 31 Python
Python连接Mssql基础教程之Python库pymssql
Sep 16 Python
python的依赖管理的实现
May 14 Python
Python OpenCV调用摄像头检测人脸并截图
Aug 20 Python
Python如何调用外部系统命令
Aug 07 Python
Python 解析简单的XML数据
Jul 24 Python
用python批量移动文件
Jan 14 Python
Python实现位图分割的效果
Nov 20 Python
利用Python实现模拟登录知乎
May 25 Python
python+selenium实现163邮箱自动登陆的方法
Dec 31 #Python
python 类对象和实例对象动态添加方法(分享)
Dec 31 #Python
利用python将图片转换成excel文档格式
Dec 30 #Python
书单|人生苦短,你还不用python!
Dec 29 #Python
python ansible服务及剧本编写
Dec 29 #Python
详解python 拆包可迭代数据如tuple, list
Dec 29 #Python
详解Python异常处理中的Finally else的功能
Dec 29 #Python
You might like
咖啡历史、消费和行业趋势
2021/03/03 咖啡文化
PHP 地址栏信息的获取代码
2009/01/07 PHP
用来解析.htpasswd文件的PHP类
2012/09/05 PHP
php异步多线程swoole用法实例
2014/11/14 PHP
php技巧小结【推荐】
2017/01/19 PHP
php实现微信支付之现金红包
2018/05/30 PHP
PHP自定义递归函数实现数组转JSON功能【支持GBK编码】
2018/07/17 PHP
Thinkphp整合阿里云OSS图片上传实例代码
2019/04/28 PHP
JQ获取动态加载的图片大小的正确方法分享
2013/11/08 Javascript
jquery插件之定时查询待处理任务数量
2014/05/01 Javascript
如何获取网站icon有哪些可行的方法
2014/06/05 Javascript
基于jQuery实现仿51job城市选择功能实例代码
2016/03/02 Javascript
javascript 组合按键事件监听实现代码
2017/02/21 Javascript
创建简单的node服务器实例(分享)
2017/06/23 Javascript
详解vue2父组件传递props异步数据到子组件的问题
2017/06/29 Javascript
在nginx上部署vue项目(history模式)的方法
2017/12/28 Javascript
详解JavaScript中关于this指向的4种情况
2019/04/18 Javascript
详解将微信小程序接口Promise化并使用async函数
2019/08/05 Javascript
Vue 图片压缩并上传至服务器功能
2020/01/15 Javascript
如何使用vue slot创建一个模态框的实例代码
2020/05/24 Javascript
vant中的toast轻提示实现代码
2020/11/04 Javascript
[01:35]辉夜杯战队访谈宣传片—iG.V
2015/12/25 DOTA
[58:58]2018DOTA2亚洲邀请赛 4.4 淘汰赛 TNC vs VG 第二场
2018/04/05 DOTA
[01:29:42]Liquid vs VP Supermajor决赛 BO 第一场 6.10
2018/07/05 DOTA
[50:12]EG vs Fnatic 2018国际邀请赛小组赛BO2 第二场 8.19
2018/08/21 DOTA
[01:52]2020年DOTA2 TI10夏季活动预告片
2020/07/15 DOTA
python实现文件名批量替换和内容替换
2014/03/20 Python
建筑安全责任书范本
2014/07/24 职场文书
村党支部书记个人对照材料汇报
2014/10/26 职场文书
党性修养心得体会2016
2016/01/21 职场文书
Ajax常用封装库——Axios的使用
2021/05/08 Javascript
Python道路车道线检测的实现
2021/06/27 Python
MySQL 用 limit 为什么会影响性能
2021/09/15 MySQL
Python 实现Mac 屏幕截图详解
2021/10/05 Python
SQL Server远程连接的设置步骤(图文)
2022/03/23 SQL Server
MySQL详解进行JDBC编程与增删改查方法
2022/06/16 MySQL