TensorFlow实现简单卷积神经网络


Posted in Python onMay 24, 2018

本文使用的数据集是MNIST,主要使用两个卷积层加一个全连接层构建的卷积神经网络。

先载入MNIST数据集(手写数字识别集),并创建默认的Interactive Session(在没有指定回话对象的情况下运行变量)

from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 
sess = tf.InteractiveSession()

在定义一个初始化函数,因为卷积神经网络有很多权重和偏置需要创建。

def weight_variable(shape): 
 initial = tf.truncated_normal(shape, stddev=0.1)
#给权重制造一些随机的噪声来打破完全对称, 
 return tf.Variable(initial) 
#使用relu,给偏置增加一些小正值0.1,用来避免死亡节点 
def bias_variable(shape): 
 initial = tf.constant(0.1, shape=shape) 
 return tf.Variable(initial)

卷积移动步长都是1代表会不遗漏的划过图片的每一个点,padding代表边界处理方式,same表示给边界加上padding让卷积的输出和输入保持同样的尺寸。

def conv2d(x,W):#2维卷积函数,x输入,w是卷积的参数,strides代表卷积模板移动步长 
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
def max_pool_2x2(x): 
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], 
       padding='SAME')

在正式设计卷积神经网络结构前,先定义输入的placeholder(类似于c++的cin,要求用户运行时输入)。因为卷积神经网络会利用到空间结构信息,因此需要将一维的输入向量转换为二维的图片结构。同时因为只有一个颜色通道,所以最后尺寸为【-1, 28,28, 1],-1代表样本数量不固定,1代表颜色通道的数量。

这里的tf.reshape是tensor变形函数。

x = tf.placeholder(tf.float32, [None, 784])# x 时特征 
y_ = tf.placeholder(tf.float32, [None, 10])# y_时真实的label 
x_image = tf.reshape(x, [-1, 28, 28,1])

接下来定义第一个卷积层。

w_conv1 = weight_variable([5, 5, 1, 32])
#代表卷积核尺寸为5X5,1个颜色通道,32个不同的卷积核,使用conv2d函数进行卷积操作, 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1)

定义第二个卷积层,与第一个卷积层一样,只不过卷积核的数量变成了64,即这层卷积会提取64种特征

w_conv2 = weight_variable([5, 5, 32, 64])#这层提取64种特征 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2)

经过两次步长为2x2的最大池化,此时图片尺寸变成了7x7,在使用tf.reshape函数,对第二个卷积层的输出tensor进行变形,将其从二维转为一维向量,在连接一个全连接层(隐含节点为1024),使用relu激活函数。

w_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

Dropout层:随机丢弃一部分节点的数据来减轻过拟合。这里是通过一个placeholder传入keep_prob比率来控制的。

#为了减轻过拟合,使用一个Dropout层 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 
 
#dropout层的输出连接一个softmax层,得到最后的概率输出 
w_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, w_fc2) + b_fc2)

定义损失函数即评测准确率操作

#损失函数,并且定义优化器为Adam 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), 
            reduction_indices=[1])) 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

开始训练

#初始化所有参数 
tf.global_variables_initializer().run() 
for i in range (20000): 
 batch = mnist.train.next_batch(50) 
 if i%100 == 0: 
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], 
             keep_prob: 1.0}) 
  print("step %d, training accuracy %g"%(i, train_accuracy)) 
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

全部训练完成后,我们在最终的测试集上进行全面的测试,得到整体的分类准确率。

print("test accuracy %g" %accuracy.eval(feed_dict={ 
 x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

这个网络,参与训练的样本数量总共为100万,共进行20000次训练迭代,使用大小为50的mini_batch。

TensorFlow实现简单卷积神经网络

因为我安装的版本时CPU版的tensorflow,所以运行较慢,这个模型最终的准确性约为99.2%,基本可以满足对手写数字识别准确率的要求。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
全面了解python字符串和字典
Jul 07 Python
Python实现简单过滤文本段的方法
May 24 Python
利用python求解物理学中的双弹簧质能系统详解
Sep 29 Python
基于scrapy的redis安装和配置方法
Jun 13 Python
Python函数any()和all()的用法及区别介绍
Sep 14 Python
对python读取CT医学图像的实例详解
Jan 24 Python
用python建立两个Y轴的XY曲线图方法
Jul 08 Python
pd.DataFrame统计各列数值多少的实例
Dec 05 Python
Django实现将views.py中的数据传递到前端html页面,并展示
Mar 16 Python
pytorch快速搭建神经网络_Sequential操作
Jun 17 Python
Python模块zipfile原理及使用方法详解
Aug 04 Python
python 第三方库paramiko的常用方式
Feb 20 Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
Django rest framework实现分页的示例
May 24 #Python
解决Matplotlib图表不能在Pycharm中显示的问题
May 24 #Python
Python系统监控模块psutil功能与经典用法分析
May 24 #Python
You might like
学习php设计模式 php实现原型模式(prototype)
2015/12/07 PHP
php 5.4 全新的代码复用Trait详解
2017/01/05 PHP
php、mysql查询当天,查询本周,查询本月的数据实例(字段是时间戳)
2017/02/04 PHP
PHP实现简单注册登录系统
2020/12/28 PHP
Ajax::prototype 源码解读
2007/01/22 Javascript
JQuery 学习笔记 选择器之三
2009/07/23 Javascript
jQuery 学习第五课 Ajax 使用说明
2010/05/17 Javascript
简单的前端js+ajax 购物车框架(入门篇)
2011/10/29 Javascript
JS异常处理的一个想法(sofish)
2013/03/14 Javascript
屏蔽script注入小例子
2013/11/12 Javascript
jQuery实现可收缩展开的级联菜单实例代码
2013/11/27 Javascript
jquery序列化form表单使用ajax提交后处理返回的json数据
2014/03/03 Javascript
使用js实现数据格式化
2014/12/03 Javascript
Javascript编写2048小游戏
2015/07/07 Javascript
JS根据生日月份和日期计算星座的简单实现方法
2016/11/24 Javascript
angular2倒计时组件使用详解
2017/01/12 Javascript
js+canvas实现动态吃豆人效果
2017/03/22 Javascript
vue-cli的工程模板与构建工具详解
2018/09/27 Javascript
[34:56]Ti4冒泡赛LGD vs Liquid 1
2014/07/14 DOTA
[06:49]2018DOTA2国际邀请赛寻真——VirtusPro傲视群雄
2018/08/12 DOTA
[46:21]Liquid vs LGD 2018国际邀请赛淘汰赛BO3 第一场 8.23
2018/08/24 DOTA
Python使用内置json模块解析json格式数据的方法
2017/07/20 Python
Python中的二维数组实例(list与numpy.array)
2018/04/13 Python
python整小时 整天时间戳获取算法示例
2019/02/20 Python
python3使用腾讯企业邮箱发送邮件的实例
2019/06/28 Python
Python进程的通信Queue、Pipe实例分析
2020/03/30 Python
如何用python开发Zeroc Ice应用
2021/01/29 Python
python绘制汉诺塔
2021/03/01 Python
css3 transform过渡抖动问题解决
2020/10/23 HTML / CSS
htnl5利用svg页面高斯模糊的方法
2018/07/20 HTML / CSS
Converse匡威法国官网:美国著名帆布鞋品牌
2018/12/05 全球购物
购房意向书
2014/04/01 职场文书
三提三创主题教育活动查摆整改措施
2014/10/25 职场文书
2015年个人现实表现材料
2014/12/10 职场文书
高中班主任工作总结(范文)
2019/08/20 职场文书
MySQL系列之十三 MySQL的复制
2021/07/02 MySQL