TensorFlow实现简单卷积神经网络


Posted in Python onMay 24, 2018

本文使用的数据集是MNIST,主要使用两个卷积层加一个全连接层构建的卷积神经网络。

先载入MNIST数据集(手写数字识别集),并创建默认的Interactive Session(在没有指定回话对象的情况下运行变量)

from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 
sess = tf.InteractiveSession()

在定义一个初始化函数,因为卷积神经网络有很多权重和偏置需要创建。

def weight_variable(shape): 
 initial = tf.truncated_normal(shape, stddev=0.1)
#给权重制造一些随机的噪声来打破完全对称, 
 return tf.Variable(initial) 
#使用relu,给偏置增加一些小正值0.1,用来避免死亡节点 
def bias_variable(shape): 
 initial = tf.constant(0.1, shape=shape) 
 return tf.Variable(initial)

卷积移动步长都是1代表会不遗漏的划过图片的每一个点,padding代表边界处理方式,same表示给边界加上padding让卷积的输出和输入保持同样的尺寸。

def conv2d(x,W):#2维卷积函数,x输入,w是卷积的参数,strides代表卷积模板移动步长 
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
def max_pool_2x2(x): 
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], 
       padding='SAME')

在正式设计卷积神经网络结构前,先定义输入的placeholder(类似于c++的cin,要求用户运行时输入)。因为卷积神经网络会利用到空间结构信息,因此需要将一维的输入向量转换为二维的图片结构。同时因为只有一个颜色通道,所以最后尺寸为【-1, 28,28, 1],-1代表样本数量不固定,1代表颜色通道的数量。

这里的tf.reshape是tensor变形函数。

x = tf.placeholder(tf.float32, [None, 784])# x 时特征 
y_ = tf.placeholder(tf.float32, [None, 10])# y_时真实的label 
x_image = tf.reshape(x, [-1, 28, 28,1])

接下来定义第一个卷积层。

w_conv1 = weight_variable([5, 5, 1, 32])
#代表卷积核尺寸为5X5,1个颜色通道,32个不同的卷积核,使用conv2d函数进行卷积操作, 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1)

定义第二个卷积层,与第一个卷积层一样,只不过卷积核的数量变成了64,即这层卷积会提取64种特征

w_conv2 = weight_variable([5, 5, 32, 64])#这层提取64种特征 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2)

经过两次步长为2x2的最大池化,此时图片尺寸变成了7x7,在使用tf.reshape函数,对第二个卷积层的输出tensor进行变形,将其从二维转为一维向量,在连接一个全连接层(隐含节点为1024),使用relu激活函数。

w_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

Dropout层:随机丢弃一部分节点的数据来减轻过拟合。这里是通过一个placeholder传入keep_prob比率来控制的。

#为了减轻过拟合,使用一个Dropout层 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 
 
#dropout层的输出连接一个softmax层,得到最后的概率输出 
w_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, w_fc2) + b_fc2)

定义损失函数即评测准确率操作

#损失函数,并且定义优化器为Adam 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), 
            reduction_indices=[1])) 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

开始训练

#初始化所有参数 
tf.global_variables_initializer().run() 
for i in range (20000): 
 batch = mnist.train.next_batch(50) 
 if i%100 == 0: 
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], 
             keep_prob: 1.0}) 
  print("step %d, training accuracy %g"%(i, train_accuracy)) 
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

全部训练完成后,我们在最终的测试集上进行全面的测试,得到整体的分类准确率。

print("test accuracy %g" %accuracy.eval(feed_dict={ 
 x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

这个网络,参与训练的样本数量总共为100万,共进行20000次训练迭代,使用大小为50的mini_batch。

TensorFlow实现简单卷积神经网络

因为我安装的版本时CPU版的tensorflow,所以运行较慢,这个模型最终的准确性约为99.2%,基本可以满足对手写数字识别准确率的要求。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现linux服务器批量修改密码并生成execl
Apr 22 Python
高性能web服务器框架Tornado简单实现restful接口及开发实例
Jul 16 Python
python读取视频流提取视频帧的两种方法
Oct 22 Python
Windows下python3.7安装教程
Jul 31 Python
对python 读取线的shp文件实例详解
Dec 22 Python
Puppeteer使用示例详解
Jun 20 Python
python爬虫selenium和phantomJs使用方法解析
Aug 08 Python
Python下利用BeautifulSoup解析HTML的实现
Jan 17 Python
快速解决jupyter notebook启动需要密码的问题
Apr 21 Python
Python编写memcached启动脚本代码实例
Aug 14 Python
Python爬虫实现自动登录、签到功能的代码
Aug 20 Python
编写python代码实现简单抽奖器
Oct 20 Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
Django rest framework实现分页的示例
May 24 #Python
解决Matplotlib图表不能在Pycharm中显示的问题
May 24 #Python
Python系统监控模块psutil功能与经典用法分析
May 24 #Python
You might like
投票管理程序
2006/10/09 PHP
ThinkPHP 连接Oracle数据库的详细教程[全]
2012/07/16 PHP
PHP基于php_imagick_st-Q8.dll实现JPG合成GIF图片的方法
2014/07/11 PHP
php通过获取头信息判断图片类型的方法
2015/06/26 PHP
Laravle eloquent 多对多模型关联实例详解
2017/11/22 PHP
php中上传文件的的解决方案
2018/09/25 PHP
aspx中利用js实现确认删除代码
2010/07/22 Javascript
js限制文本框只能输入数字(正则表达式)
2012/07/15 Javascript
解析Jquery中如何把一段html代码动态写入到DIV中(实例说明)
2013/07/09 Javascript
javascript动态控制服务器控件实例
2014/09/05 Javascript
javascript实现动态表头及表列的展现方法
2015/07/14 Javascript
七个不允许错过的jQuery小技巧
2015/12/21 Javascript
jQuery Ajax 异步加载显示等待效果代码分享
2016/08/01 Javascript
JavaScript中 ES6 generator数据类型详解
2016/08/11 Javascript
AngularJS中的JSONP实例解析
2016/12/01 Javascript
最常用的jQuery表单验证(简单)
2017/05/23 jQuery
js实现扫雷小程序的示例代码
2017/09/27 Javascript
Vue父组件向子组件传值以及data和props的区别详解
2020/03/02 Javascript
分享一款超好用的JavaScript 打包压缩工具
2020/04/26 Javascript
基于vue+echarts数据可视化大屏展示的实现
2020/12/25 Vue.js
[03:03]DOTA2校园争霸赛 济南城市决赛欢乐发奖活动
2013/10/21 DOTA
[00:10]DOTA2全国高校联赛速递
2018/05/30 DOTA
Python中注释(多行注释和单行注释)的用法实例
2019/08/28 Python
python通过opencv实现图片裁剪原理解析
2020/01/19 Python
Python分析最近大火的网剧《隐秘的角落》
2020/07/02 Python
Python3爬虫中Selenium的用法详解
2020/07/10 Python
html5手机键盘弹出收起的处理
2020/01/20 HTML / CSS
德国原装品牌香水、化妆品和手表网站:BRASTY.DE
2016/10/16 全球购物
小学教师培训感言
2014/02/11 职场文书
环保小标语
2014/06/13 职场文书
学风建设演讲稿
2014/09/12 职场文书
卖房协议书样本
2014/10/30 职场文书
建筑工程挂靠协议书
2016/03/23 职场文书
入党心得体会
2019/06/20 职场文书
研究生学习计划书应该怎么写?
2019/09/10 职场文书
Windows server 2012 NTP时间同步的实现
2022/06/25 Servers