TensorFlow实现简单卷积神经网络


Posted in Python onMay 24, 2018

本文使用的数据集是MNIST,主要使用两个卷积层加一个全连接层构建的卷积神经网络。

先载入MNIST数据集(手写数字识别集),并创建默认的Interactive Session(在没有指定回话对象的情况下运行变量)

from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 
sess = tf.InteractiveSession()

在定义一个初始化函数,因为卷积神经网络有很多权重和偏置需要创建。

def weight_variable(shape): 
 initial = tf.truncated_normal(shape, stddev=0.1)
#给权重制造一些随机的噪声来打破完全对称, 
 return tf.Variable(initial) 
#使用relu,给偏置增加一些小正值0.1,用来避免死亡节点 
def bias_variable(shape): 
 initial = tf.constant(0.1, shape=shape) 
 return tf.Variable(initial)

卷积移动步长都是1代表会不遗漏的划过图片的每一个点,padding代表边界处理方式,same表示给边界加上padding让卷积的输出和输入保持同样的尺寸。

def conv2d(x,W):#2维卷积函数,x输入,w是卷积的参数,strides代表卷积模板移动步长 
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
def max_pool_2x2(x): 
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], 
       padding='SAME')

在正式设计卷积神经网络结构前,先定义输入的placeholder(类似于c++的cin,要求用户运行时输入)。因为卷积神经网络会利用到空间结构信息,因此需要将一维的输入向量转换为二维的图片结构。同时因为只有一个颜色通道,所以最后尺寸为【-1, 28,28, 1],-1代表样本数量不固定,1代表颜色通道的数量。

这里的tf.reshape是tensor变形函数。

x = tf.placeholder(tf.float32, [None, 784])# x 时特征 
y_ = tf.placeholder(tf.float32, [None, 10])# y_时真实的label 
x_image = tf.reshape(x, [-1, 28, 28,1])

接下来定义第一个卷积层。

w_conv1 = weight_variable([5, 5, 1, 32])
#代表卷积核尺寸为5X5,1个颜色通道,32个不同的卷积核,使用conv2d函数进行卷积操作, 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1)

定义第二个卷积层,与第一个卷积层一样,只不过卷积核的数量变成了64,即这层卷积会提取64种特征

w_conv2 = weight_variable([5, 5, 32, 64])#这层提取64种特征 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2)

经过两次步长为2x2的最大池化,此时图片尺寸变成了7x7,在使用tf.reshape函数,对第二个卷积层的输出tensor进行变形,将其从二维转为一维向量,在连接一个全连接层(隐含节点为1024),使用relu激活函数。

w_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

Dropout层:随机丢弃一部分节点的数据来减轻过拟合。这里是通过一个placeholder传入keep_prob比率来控制的。

#为了减轻过拟合,使用一个Dropout层 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 
 
#dropout层的输出连接一个softmax层,得到最后的概率输出 
w_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, w_fc2) + b_fc2)

定义损失函数即评测准确率操作

#损失函数,并且定义优化器为Adam 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), 
            reduction_indices=[1])) 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

开始训练

#初始化所有参数 
tf.global_variables_initializer().run() 
for i in range (20000): 
 batch = mnist.train.next_batch(50) 
 if i%100 == 0: 
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], 
             keep_prob: 1.0}) 
  print("step %d, training accuracy %g"%(i, train_accuracy)) 
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

全部训练完成后,我们在最终的测试集上进行全面的测试,得到整体的分类准确率。

print("test accuracy %g" %accuracy.eval(feed_dict={ 
 x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

这个网络,参与训练的样本数量总共为100万,共进行20000次训练迭代,使用大小为50的mini_batch。

TensorFlow实现简单卷积神经网络

因为我安装的版本时CPU版的tensorflow,所以运行较慢,这个模型最终的准确性约为99.2%,基本可以满足对手写数字识别准确率的要求。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的map、reduce和filter浅析
Apr 26 Python
Python最长公共子串算法实例
Mar 07 Python
Python实现对比不同字体中的同一字符的显示效果
Apr 23 Python
详解Python Socket网络编程
Jan 05 Python
python bottle框架支持jquery ajax的RESTful风格的PUT和DELETE方法
May 24 Python
Python解决N阶台阶走法问题的方法分析
Dec 28 Python
快速了解Python相对导入
Jan 12 Python
ubuntu17.4下为python和python3装上pip的方法
Jun 12 Python
Python编程中类与类的关系详解
Aug 08 Python
Django 自定义分页器的实现代码
Nov 24 Python
使用Python实现将多表分批次从数据库导出到Excel
May 15 Python
聊聊Python中关于a=[[]]*3的反思
Jun 02 Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
Django rest framework实现分页的示例
May 24 #Python
解决Matplotlib图表不能在Pycharm中显示的问题
May 24 #Python
Python系统监控模块psutil功能与经典用法分析
May 24 #Python
You might like
composer.lock文件的作用
2016/02/03 PHP
PHP实现二叉树深度优先遍历(前序、中序、后序)和广度优先遍历(层次)实例详解
2018/04/20 PHP
window resize和scroll事件的基本优化思路
2014/04/29 Javascript
jquery插件unobtrusive实现片段式加载
2015/06/15 Javascript
AngularJS通过$location获取及改变当前页面的URL
2016/09/23 Javascript
js当前页面登录注册框,固定div,底层阴影的实例代码
2016/10/04 Javascript
express文件上传中间件Multer详解
2016/10/24 Javascript
JavaScript常见的五种数组去重的方式
2016/12/15 Javascript
AngularJS报错$apply already in progress的解决方法分析
2017/01/30 Javascript
微信小程序实战之顶部导航栏(选项卡)(1)
2020/06/19 Javascript
JavaScript学习笔记之函数记忆
2017/09/06 Javascript
vue双花括号的使用方法 附练习题
2017/11/07 Javascript
Element输入框带历史查询记录的实现示例
2019/01/15 Javascript
从理论角度讨论JavaScript闭包
2019/04/03 Javascript
Vue函数式组件的应用实例详解
2019/08/30 Javascript
Vue的状态管理vuex使用方法详解
2020/02/05 Javascript
JavaScript 如何在浏览器中使用摄像头
2020/12/02 Javascript
python3.0 字典key排序
2008/12/24 Python
Python类属性与实例属性用法分析
2015/05/09 Python
用python写的一个wordpress的采集程序
2016/02/27 Python
PyCharm设置每行最大长度限制的方法
2019/01/16 Python
python安装scipy的方法步骤
2019/06/26 Python
对python中UDP,socket的使用详解
2019/08/22 Python
python将print输出的信息保留到日志文件中
2019/09/27 Python
关于Numpy数据类型对象(dtype)使用详解
2019/11/27 Python
Python使用PyQt5/PySide2编写一个极简的音乐播放器功能
2020/02/07 Python
python logging设置level失败的解决方法
2020/02/19 Python
HTML5 video标签(播放器)学习笔记(一):使用入门
2015/04/24 HTML / CSS
爱淘宝:淘宝网购物分享平台
2017/04/28 全球购物
Capitol Lighting的1800lighting.com:住宅和商业照明
2019/04/10 全球购物
Ray-Ban雷朋奥地利官网:全球领先的太阳眼镜品牌
2020/10/12 全球购物
写一个函数,求一个字符串的长度。在main函数中输入字符串,并输出其长度
2015/11/18 面试题
linux面试题参考答案(6)
2014/08/29 面试题
如何用Java实现列出某个目录下的所有子目录
2015/07/20 面试题
《小小竹排画中游》教学反思
2014/02/26 职场文书
mysql中DCL常用的用户和权限控制
2022/03/31 MySQL