python tensorflow基于cnn实现手写数字识别


Posted in Python onJanuary 01, 2018

一份基于cnn的手写数字自识别的代码,供大家参考,具体内容如下

# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 加载数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# 以交互式方式启动session
# 如果不使用交互式session,则在启动session前必须
# 构建整个计算图,才能启动该计算图
sess = tf.InteractiveSession()

"""构建计算图"""
# 通过占位符来为输入图像和目标输出类别创建节点
# shape参数是可选的,有了它tensorflow可以自动捕获维度不一致导致的错误
x = tf.placeholder("float", shape=[None, 784]) # 原始输入
y_ = tf.placeholder("float", shape=[None, 10]) # 目标值

# 为了不在建立模型的时候反复做初始化操作,
# 我们定义两个函数用于初始化
def weight_variable(shape):
 # 截尾正态分布,stddev是正态分布的标准偏差
 initial = tf.truncated_normal(shape=shape, stddev=0.1)
 return tf.Variable(initial)
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)

# 卷积核池化,步长为1,0边距
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
       strides=[1, 2, 2, 1], padding='SAME')

"""第一层卷积"""
# 由一个卷积和一个最大池化组成。滤波器5x5中算出32个特征,是因为使用32个滤波器进行卷积
# 卷积的权重张量形状是[5, 5, 1, 32],1是输入通道的个数,32是输出通道个数
W_conv1 = weight_variable([5, 5, 1, 32])
# 每一个输出通道都有一个偏置量
b_conv1 = bias_variable([32])

# 位了使用卷积,必须将输入转换成4维向量,2、3维表示图片的宽、高
# 最后一维表示图片的颜色通道(因为是灰度图像所以通道数维1,RGB图像通道数为3)
x_image = tf.reshape(x, [-1, 28, 28, 1])

# 第一层的卷积结果,使用Relu作为激活函数
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1))
# 第一层卷积后的池化结果
h_pool1 = max_pool_2x2(h_conv1)

"""第二层卷积"""
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

"""全连接层"""
# 图片尺寸减小到7*7,加入一个有1024个神经元的全连接层
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
# 将最后的池化层输出张量reshape成一维向量
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
# 全连接层的输出
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

"""使用Dropout减少过拟合"""
# 使用placeholder占位符来表示神经元的输出在dropout中保持不变的概率
# 在训练的过程中启用dropout,在测试过程中关闭dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

"""输出层"""
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
# 模型预测输出
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

# 交叉熵损失
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))

# 模型训练,使用AdamOptimizer来做梯度最速下降
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# 正确预测,得到True或False的List
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1))
# 将布尔值转化成浮点数,取平均值作为精确度
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# 在session中先初始化变量才能在session中调用
sess.run(tf.global_variables_initializer())

# 迭代优化模型
for i in range(2000):
 # 每次取50个样本进行训练
 batch = mnist.train.next_batch(50)
 if i%100 == 0:
  train_accuracy = accuracy.eval(feed_dict={
   x: batch[0], y_: batch[1], keep_prob: 1.0}) # 模型中间不使用dropout
  print("step %d, training accuracy %g" % (i, train_accuracy))
 train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob: 0.5})
print("test accuracy %g" % accuracy.eval(feed_dict={
   x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

做了2000次迭代,在测试集上的识别精度能够到0.9772……

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用SAX解析xml实例
Nov 21 Python
Python计算回文数的方法
Mar 11 Python
Python版微信红包分配算法
May 04 Python
python更新列表的方法
Jul 28 Python
使用Python解析JSON数据的基本方法
Oct 15 Python
实例讲解Python设计模式编程之工厂方法模式的使用
Mar 02 Python
Python字典数据对象拆分的简单实现方法
Dec 05 Python
python3.5+tesseract+adb实现西瓜视频或头脑王者辅助答题
Jan 17 Python
Python字符串中删除特定字符的方法
Jan 15 Python
python根据完整路径获得盘名/路径名/文件名/文件扩展名的方法
Apr 22 Python
python 利用百度API识别图片文字(多线程版)
Dec 14 Python
python的setattr函数实例用法
Dec 16 Python
python+selenium实现163邮箱自动登陆的方法
Dec 31 #Python
python 类对象和实例对象动态添加方法(分享)
Dec 31 #Python
利用python将图片转换成excel文档格式
Dec 30 #Python
书单|人生苦短,你还不用python!
Dec 29 #Python
python ansible服务及剧本编写
Dec 29 #Python
详解python 拆包可迭代数据如tuple, list
Dec 29 #Python
详解Python异常处理中的Finally else的功能
Dec 29 #Python
You might like
PHP下10件你也许并不了解的事情
2008/09/11 PHP
Uchome1.2 1.5 代码学习 common.php
2009/04/24 PHP
PHP 批量删除数据的方法分析
2009/10/30 PHP
PHP下利用shell后台运行PHP脚本,并获取该脚本的Process ID的代码
2011/09/19 PHP
php类的自动加载操作实例详解
2016/09/28 PHP
基于jQueryUI和Corethink实现百度的搜索提示功能
2016/11/09 PHP
window.parent调用父框架时 ie跟火狐不兼容问题
2009/07/30 Javascript
推荐10个超棒的jQuery工具提示插件
2011/10/11 Javascript
javascript中获取下个月一号,是星期几
2012/06/01 Javascript
cookie.js 加载顺序问题怎么才有效
2013/07/31 Javascript
node.js中的buffer.fill方法使用说明
2014/12/14 Javascript
jqueryMobile使用示例分享
2016/01/12 Javascript
JavaScript程序中的流程控制语句用法总结
2016/05/23 Javascript
jQuery validate插件功能与用法详解
2016/12/15 Javascript
详解AngularJS用Interceptors来统一处理HTTP请求和响应
2017/06/08 Javascript
nodejs结合socket.io实现websocket通信功能的方法
2018/01/12 NodeJs
webpack+vue-cil中proxyTable处理跨域的方法
2018/07/20 Javascript
微信小程序用户拒绝授权的处理方法详解
2019/09/20 Javascript
JavaScript 如何计算文本的行数的实现
2020/09/14 Javascript
Python对list列表结构中的值进行去重的方法总结
2016/05/07 Python
Python中的迭代器与生成器高级用法解析
2016/06/28 Python
python基于twisted框架编写简单聊天室
2018/01/02 Python
python实现列表中由数值查到索引的方法
2018/06/27 Python
python框架Django实战商城项目之工程搭建过程图文详解
2020/03/09 Python
Python:__eq__和__str__函数的使用示例
2020/09/26 Python
Python的3种运行方式:命令行窗口、Python解释器、IDLE的实现
2020/10/10 Python
Pycharm操作Git及GitHub的步骤详解
2020/10/27 Python
selenium框架中driver.close()和driver.quit()关闭浏览器
2020/12/08 Python
目前不被任何主流浏览器支持的CSS3属性汇总
2014/07/21 HTML / CSS
美国知名保健品网站:LuckyVitamin(支持中文)
2017/08/09 全球购物
美国背景检查、公共记录和人物搜索网站:BeenVerified
2018/02/25 全球购物
亚马逊新加坡官方网站:Amazon.sg
2020/03/25 全球购物
个人自我剖析材料
2014/09/30 职场文书
使用react-virtualized实现图片动态高度长列表的问题
2021/05/28 Javascript
5人制售《绝地求生》游戏外挂获利500多万元 被判刑
2022/03/31 其他游戏
MySQL中优化SQL语句的方法(show status、explain分析服务器状态信息)
2022/04/09 MySQL