python tensorflow基于cnn实现手写数字识别


Posted in Python onJanuary 01, 2018

一份基于cnn的手写数字自识别的代码,供大家参考,具体内容如下

# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 加载数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# 以交互式方式启动session
# 如果不使用交互式session,则在启动session前必须
# 构建整个计算图,才能启动该计算图
sess = tf.InteractiveSession()

"""构建计算图"""
# 通过占位符来为输入图像和目标输出类别创建节点
# shape参数是可选的,有了它tensorflow可以自动捕获维度不一致导致的错误
x = tf.placeholder("float", shape=[None, 784]) # 原始输入
y_ = tf.placeholder("float", shape=[None, 10]) # 目标值

# 为了不在建立模型的时候反复做初始化操作,
# 我们定义两个函数用于初始化
def weight_variable(shape):
 # 截尾正态分布,stddev是正态分布的标准偏差
 initial = tf.truncated_normal(shape=shape, stddev=0.1)
 return tf.Variable(initial)
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)

# 卷积核池化,步长为1,0边距
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
       strides=[1, 2, 2, 1], padding='SAME')

"""第一层卷积"""
# 由一个卷积和一个最大池化组成。滤波器5x5中算出32个特征,是因为使用32个滤波器进行卷积
# 卷积的权重张量形状是[5, 5, 1, 32],1是输入通道的个数,32是输出通道个数
W_conv1 = weight_variable([5, 5, 1, 32])
# 每一个输出通道都有一个偏置量
b_conv1 = bias_variable([32])

# 位了使用卷积,必须将输入转换成4维向量,2、3维表示图片的宽、高
# 最后一维表示图片的颜色通道(因为是灰度图像所以通道数维1,RGB图像通道数为3)
x_image = tf.reshape(x, [-1, 28, 28, 1])

# 第一层的卷积结果,使用Relu作为激活函数
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1))
# 第一层卷积后的池化结果
h_pool1 = max_pool_2x2(h_conv1)

"""第二层卷积"""
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

"""全连接层"""
# 图片尺寸减小到7*7,加入一个有1024个神经元的全连接层
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
# 将最后的池化层输出张量reshape成一维向量
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
# 全连接层的输出
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

"""使用Dropout减少过拟合"""
# 使用placeholder占位符来表示神经元的输出在dropout中保持不变的概率
# 在训练的过程中启用dropout,在测试过程中关闭dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

"""输出层"""
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
# 模型预测输出
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

# 交叉熵损失
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))

# 模型训练,使用AdamOptimizer来做梯度最速下降
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# 正确预测,得到True或False的List
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1))
# 将布尔值转化成浮点数,取平均值作为精确度
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# 在session中先初始化变量才能在session中调用
sess.run(tf.global_variables_initializer())

# 迭代优化模型
for i in range(2000):
 # 每次取50个样本进行训练
 batch = mnist.train.next_batch(50)
 if i%100 == 0:
  train_accuracy = accuracy.eval(feed_dict={
   x: batch[0], y_: batch[1], keep_prob: 1.0}) # 模型中间不使用dropout
  print("step %d, training accuracy %g" % (i, train_accuracy))
 train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob: 0.5})
print("test accuracy %g" % accuracy.eval(feed_dict={
   x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

做了2000次迭代,在测试集上的识别精度能够到0.9772……

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
零基础写python爬虫之urllib2使用指南
Nov 05 Python
在Django框架中编写Context处理器的方法
Jul 20 Python
python 实现网上商城,转账,存取款等功能的信用卡系统
Jul 15 Python
python 类对象和实例对象动态添加方法(分享)
Dec 31 Python
Python之dict(或对象)与json之间的互相转化实例
Jun 05 Python
在Pycharm terminal中字体大小设置的方法
Jan 16 Python
通过字符串导入 Python 模块的方法详解
Oct 27 Python
python 实现绘制整齐的表格
Nov 18 Python
在pycharm中为项目导入anacodna环境的操作方法
Feb 12 Python
django实现HttpResponse返回json数据为中文
Mar 27 Python
sqlalchemy实现时间列自动更新教程
Sep 02 Python
Python爬虫教程之利用正则表达式匹配网页内容
Dec 08 Python
python+selenium实现163邮箱自动登陆的方法
Dec 31 #Python
python 类对象和实例对象动态添加方法(分享)
Dec 31 #Python
利用python将图片转换成excel文档格式
Dec 30 #Python
书单|人生苦短,你还不用python!
Dec 29 #Python
python ansible服务及剧本编写
Dec 29 #Python
详解python 拆包可迭代数据如tuple, list
Dec 29 #Python
详解Python异常处理中的Finally else的功能
Dec 29 #Python
You might like
咖啡知识 咖啡养豆要养多久 排气又是什么
2021/03/06 新手入门
PHP系统流量分析的程序
2006/10/09 PHP
PHP 工厂模式使用方法
2010/05/18 PHP
PHP的serialize序列化数据以及JSON格式化数据分析
2015/10/10 PHP
PHP正则表达式入门教程(推荐)
2016/05/18 PHP
PHP PDOStatement::closeCursor讲解
2019/01/30 PHP
将json对象转换为字符串的方法
2014/02/20 Javascript
jquery实现翻动fadeIn显示的方法
2015/03/05 Javascript
javascript制作幻灯片(360度全景图片)
2015/07/28 Javascript
jquery分隔Url的param方法(推荐)
2016/05/25 Javascript
JS 日期与时间戮相互转化的简单实例
2016/06/22 Javascript
微信小程序scroll-view实现横向滚动和上拉加载示例
2017/03/06 Javascript
了解VUE的render函数的使用
2017/06/08 Javascript
jQuery实现base64前台加密解密功能详解
2017/08/29 jQuery
Webpack实战加载SVG的方法
2017/12/26 Javascript
使用webpack搭建vue环境的教程详解
2019/12/31 Javascript
[01:26]神话结束了,却也刚刚开始——DOTA2新英雄玛尔斯驾临战场
2019/03/10 DOTA
python 切片和range()用法说明
2013/03/24 Python
在 Python 应用中使用 MongoDB的方法
2017/01/05 Python
详解python之多进程和进程池(Processing库)
2017/06/09 Python
python的pdb调试命令的命令整理及实例
2017/07/12 Python
浅谈python中的数字类型与处理工具
2017/08/02 Python
python机器学习实战之K均值聚类
2017/12/20 Python
python版大富翁源代码分享
2018/11/19 Python
Django对models里的objects的使用详解
2019/08/17 Python
Python reversed函数及使用方法解析
2020/03/17 Python
python输入一个水仙花数(三位数) 输出百位十位个位实例
2020/05/03 Python
浅谈多卡服务器下隐藏部分 GPU 和 TensorFlow 的显存使用设置
2020/06/30 Python
使用html5 canvas绘制圆环动效
2019/06/03 HTML / CSS
C语言基础笔试题
2013/04/27 面试题
化工操作工岗位职责
2014/04/29 职场文书
保护环境建议书100字
2014/05/13 职场文书
党校毕业心得体会
2014/09/13 职场文书
2014年新农村建设工作总结
2014/12/01 职场文书
工会经费申请报告
2015/05/15 职场文书
python和C/C++混合编程之使用ctypes调用 C/C++的dll
2022/04/29 Python