TensorFlow实现卷积神经网络


Posted in Python onMay 24, 2018

本文实例为大家分享了TensorFlow实现卷积神经网络的具体代码,供大家参考,具体内容如下

代码(源代码都有详细的注释)和数据集可以在github下载:

# -*- coding: utf-8 -*-
'''卷积神经网络测试MNIST数据'''

#########导入MNIST数据########
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

# 创建默认InteractiveSession
sess = tf.InteractiveSession()


#########卷积网络会有很多的权重和偏置需要创建,先定义好初始化函数以便复用########
# 给权重制造一些随机噪声打破完全对称(比如截断的正态分布噪声,标准差设为0.1)
def weight_variable(shape):
 initial = tf.truncated_normal(shape, stddev=0.1)
 return tf.Variable(initial)
# 因为我们要使用ReLU,也给偏置增加一些小的正值(0.1)用来避免死亡节点(dead neurons)
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)


########卷积层、池化层接下来重复使用的,分别定义创建函数########
# tf.nn.conv2d是TensorFlow中的2维卷积函数
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
# 使用2*2的最大池化
def max_pool_2x2(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')


########正式设计卷积神经网络之前先定义placeholder########
# x是特征,y_是真实label。将图片数据从1D转为2D。使用tensor的变形函数tf.reshape
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
x_image = tf.reshape(x,[-1,28,28,1])


########设计卷积神经网络########
# 第一层卷积
# 卷积核尺寸为5*5,1个颜色通道,32个不同的卷积核
W_conv1 = weight_variable([5, 5, 1, 32])
# 用conv2d函数进行卷积操作,加上偏置
b_conv1 = bias_variable([32])
# 把x_image和权值向量进行卷积,加上偏置项,然后应用ReLU激活函数,
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
# 对卷积的输出结果进行池化操作
h_pool1 = max_pool_2x2(h_conv1)

# 第二层卷积(和第一层大致相同,卷积核为64,这一层卷积会提取64种特征)
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

# 全连接层。隐含节点数1024。使用ReLU激活函数
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

# 为了防止过拟合,在输出层之前加Dropout层
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# 输出层。添加一个softmax层,就像softmax regression一样。得到概率输出。
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)


########模型训练设置########
# 定义loss function为cross entropy,优化器使用Adam,并给予一个比较小的学习速率1e-4
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv),reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# 定义评测准确率的操作
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


########开始训练过程########
# 初始化所有参数
tf.global_variables_initializer().run()

# 训练(设置训练时Dropout的kepp_prob比率为0.5。mini-batch为50,进行2000次迭代训练,参与训练样本5万)
# 其中每进行100次训练,对准确率进行一次评测keep_prob设置为1,用以实时监测模型的性能
for i in range(1000):
 batch = mnist.train.next_batch(50)
 if i%100 == 0:
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0})
  print "-->step %d, training accuracy %.4f"%(i, train_accuracy)
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
# 全部训练完成之后,在最终测试集上进行全面测试,得到整体的分类准确率
print "卷积神经网络在MNIST数据集正确率: %g"%accuracy.eval(feed_dict={
  x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})

TensorFlow实现卷积神经网络

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python标准库与第三方库详解
Jul 22 Python
总结Python编程中函数的使用要点
Mar 20 Python
python对list中的每个元素进行某种操作的方法
Jun 29 Python
python实现周期方波信号频谱图
Jul 21 Python
python语音识别实践之百度语音API
Aug 30 Python
Python 分发包中添加额外文件的方法
Aug 16 Python
django 连接数据库出现1045错误的解决方式
May 14 Python
python如何操作mysql
Aug 17 Python
基于Python爬取京东双十一商品价格曲线
Oct 23 Python
python将下载到本地m3u8视频合成MP4的代码详解
Nov 24 Python
Python的Tqdm模块实现进度条配置
Feb 24 Python
python包的导入方式总结
Mar 02 Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
You might like
深入了解php4(2)--重访过去
2006/10/09 PHP
一个简洁的多级别论坛
2006/10/09 PHP
PHP连接SQLServer2005方法及代码
2013/12/26 PHP
PHP中的gzcompress、gzdeflate、gzencode函数详解
2014/07/29 PHP
PHP将Excel导入数据库及数据库数据导出至Excel的方法
2015/06/24 PHP
网页的分页下标生成代码(PHP后端方法)
2016/02/03 PHP
PHP调用API接口实现天气查询功能的示例
2017/09/21 PHP
PHP开发API接口签名生成及验证操作示例
2020/05/27 PHP
javascript实现的listview效果
2007/04/28 Javascript
深入理解JavaScript系列(11) 执行上下文(Execution Contexts)
2012/01/15 Javascript
浅析JavaScript中的typeof运算符
2013/11/30 Javascript
探讨JavaScript中声明全局变量三种方式的异同
2013/12/03 Javascript
document.execCommand()的用法小结
2014/01/08 Javascript
在easyUI开发中,出现jquery.easyui.min.js函数库问题的解决办法
2015/09/11 Javascript
Angularjs---项目搭建图文教程
2016/07/08 Javascript
如何用JS判断两个数字的大小
2016/07/21 Javascript
关于List.ToArray()方法的效率测试
2016/09/30 Javascript
React利用插件和不用插件实现双向绑定的方法详解
2017/07/03 Javascript
Webpack 之 babel-loader文件预处理器详解
2018/03/23 Javascript
微信小程序websocket聊天室的实现示例代码
2019/02/12 Javascript
Python实现抓取网页并且解析的实例
2014/09/20 Python
Python读取图片属性信息的实现方法
2016/09/11 Python
Python读取指定目录下指定后缀文件并保存为docx
2017/04/23 Python
python-str,list,set间的转换实例
2018/06/27 Python
浅谈python中str字符串和unicode对象字符串的拼接问题
2018/12/04 Python
python实现三维拟合的方法
2018/12/29 Python
使用python和pygame制作挡板弹球游戏
2019/12/03 Python
基于python实现复制文件并重命名
2020/09/16 Python
用ldap作为django后端用户登录验证的实现
2020/12/07 Python
机电专业毕业生求职信
2013/10/27 职场文书
法制宣传月活动方案
2014/05/11 职场文书
展览会邀请函
2015/02/02 职场文书
趣味运动会通讯稿
2015/07/18 职场文书
只需要12页,掌握撰写一流商业计划书的技巧
2019/05/07 职场文书
探讨Java中的深浅拷贝问题
2021/06/26 Java/Android
Go语言实现Base64、Base58编码与解码
2021/07/26 Golang