TensorFlow实现简单卷积神经网络


Posted in Python onMay 24, 2018

本文使用的数据集是MNIST,主要使用两个卷积层加一个全连接层构建的卷积神经网络。

先载入MNIST数据集(手写数字识别集),并创建默认的Interactive Session(在没有指定回话对象的情况下运行变量)

from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 
sess = tf.InteractiveSession()

在定义一个初始化函数,因为卷积神经网络有很多权重和偏置需要创建。

def weight_variable(shape): 
 initial = tf.truncated_normal(shape, stddev=0.1)
#给权重制造一些随机的噪声来打破完全对称, 
 return tf.Variable(initial) 
#使用relu,给偏置增加一些小正值0.1,用来避免死亡节点 
def bias_variable(shape): 
 initial = tf.constant(0.1, shape=shape) 
 return tf.Variable(initial)

卷积移动步长都是1代表会不遗漏的划过图片的每一个点,padding代表边界处理方式,same表示给边界加上padding让卷积的输出和输入保持同样的尺寸。

def conv2d(x,W):#2维卷积函数,x输入,w是卷积的参数,strides代表卷积模板移动步长 
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
def max_pool_2x2(x): 
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], 
       padding='SAME')

在正式设计卷积神经网络结构前,先定义输入的placeholder(类似于c++的cin,要求用户运行时输入)。因为卷积神经网络会利用到空间结构信息,因此需要将一维的输入向量转换为二维的图片结构。同时因为只有一个颜色通道,所以最后尺寸为【-1, 28,28, 1],-1代表样本数量不固定,1代表颜色通道的数量。

这里的tf.reshape是tensor变形函数。

x = tf.placeholder(tf.float32, [None, 784])# x 时特征 
y_ = tf.placeholder(tf.float32, [None, 10])# y_时真实的label 
x_image = tf.reshape(x, [-1, 28, 28,1])

接下来定义第一个卷积层。

w_conv1 = weight_variable([5, 5, 1, 32])
#代表卷积核尺寸为5X5,1个颜色通道,32个不同的卷积核,使用conv2d函数进行卷积操作, 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1)

定义第二个卷积层,与第一个卷积层一样,只不过卷积核的数量变成了64,即这层卷积会提取64种特征

w_conv2 = weight_variable([5, 5, 32, 64])#这层提取64种特征 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2)

经过两次步长为2x2的最大池化,此时图片尺寸变成了7x7,在使用tf.reshape函数,对第二个卷积层的输出tensor进行变形,将其从二维转为一维向量,在连接一个全连接层(隐含节点为1024),使用relu激活函数。

w_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

Dropout层:随机丢弃一部分节点的数据来减轻过拟合。这里是通过一个placeholder传入keep_prob比率来控制的。

#为了减轻过拟合,使用一个Dropout层 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 
 
#dropout层的输出连接一个softmax层,得到最后的概率输出 
w_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, w_fc2) + b_fc2)

定义损失函数即评测准确率操作

#损失函数,并且定义优化器为Adam 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), 
            reduction_indices=[1])) 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

开始训练

#初始化所有参数 
tf.global_variables_initializer().run() 
for i in range (20000): 
 batch = mnist.train.next_batch(50) 
 if i%100 == 0: 
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], 
             keep_prob: 1.0}) 
  print("step %d, training accuracy %g"%(i, train_accuracy)) 
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

全部训练完成后,我们在最终的测试集上进行全面的测试,得到整体的分类准确率。

print("test accuracy %g" %accuracy.eval(feed_dict={ 
 x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

这个网络,参与训练的样本数量总共为100万,共进行20000次训练迭代,使用大小为50的mini_batch。

TensorFlow实现简单卷积神经网络

因为我安装的版本时CPU版的tensorflow,所以运行较慢,这个模型最终的准确性约为99.2%,基本可以满足对手写数字识别准确率的要求。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 域名分析工具实现代码
Jul 15 Python
Python程序设计入门(5)类的使用简介
Jun 16 Python
使用Python的内建模块collections的教程
Apr 28 Python
日常整理python执行系统命令的常见方法(全)
Oct 22 Python
Python的几个高级语法概念浅析(lambda表达式闭包装饰器)
May 28 Python
Python基础语言学习笔记总结(精华)
Nov 14 Python
python正则实现提取电话功能
Feb 24 Python
django 做 migrate 时 表已存在的处理方法
Aug 31 Python
IronPython连接MySQL的方法步骤
Dec 27 Python
简单的Python人脸识别系统
Jul 14 Python
python request 模块详细介绍
Nov 10 Python
Python环境搭建过程从安装到Hello World
Feb 05 Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
Django rest framework实现分页的示例
May 24 #Python
解决Matplotlib图表不能在Pycharm中显示的问题
May 24 #Python
Python系统监控模块psutil功能与经典用法分析
May 24 #Python
You might like
以文本方式上传二进制文件的PHP程序
2006/10/09 PHP
PHP 组件化编程技巧
2009/06/06 PHP
ThinkPHP 防止表单重复提交的方法
2011/08/08 PHP
基于PHP CURL用法的深入分析
2013/06/09 PHP
PHP检测用户语言的方法
2015/06/15 PHP
php自动更新版权信息显示的方法
2015/06/19 PHP
extJS中常用的4种Ajax异步提交方式
2014/03/07 Javascript
javascript记录文本框内文字个数检测文字个数变化
2014/10/14 Javascript
jQuery实现checkbox全选的方法
2015/06/10 Javascript
JQuery validate插件Remote用法大全
2016/05/15 Javascript
jQuery 翻页组件yunm.pager.js实现div局部刷新的思路
2016/08/11 Javascript
js 开发之autocomplete="off"在chrom中失效的解决办法
2017/09/28 Javascript
基于vue监听滚动事件实现锚点链接平滑滚动的方法
2018/01/17 Javascript
浅谈webpack打包过程中因为图片的路径导致的问题
2018/02/21 Javascript
如何理解Vue的v-model指令的使用方法
2018/07/19 Javascript
JavaScript实现的前端AES加密解密功能【基于CryptoJS】
2018/08/28 Javascript
微信小程序实现下拉菜单切换效果
2020/03/30 Javascript
vue中使用带隐藏文本信息的图片、图片水印的方法
2020/04/24 Javascript
vue render函数动态加载img的src路径操作
2020/10/26 Javascript
Python检测QQ在线状态的方法
2015/05/09 Python
python爬虫入门教程--HTML文本的解析库BeautifulSoup(四)
2017/05/25 Python
Python批量提取PDF文件中文本的脚本
2018/03/14 Python
解决python中画图时x,y轴名称出现中文乱码的问题
2019/01/29 Python
Python对ElasticSearch获取数据及操作
2019/04/24 Python
python实现在cmd窗口显示彩色文字
2019/06/24 Python
Python Web静态服务器非堵塞模式实现方法示例
2019/11/21 Python
Python绘制全球疫情变化地图的实例代码
2020/04/20 Python
奥林匹克的口号
2014/06/13 职场文书
建筑工地标语
2014/06/18 职场文书
高效课堂标语
2014/06/26 职场文书
社团活动总结模板
2014/06/30 职场文书
2015年设计师个人工作总结
2015/04/25 职场文书
惊涛骇浪观后感
2015/06/05 职场文书
瞿秋白纪念馆观后感
2015/06/10 职场文书
Java中多线程下载图片并压缩能提高效率吗
2021/07/01 Java/Android
golang为什么要统一错误处理
2022/04/03 Golang