TensorFlow实现卷积神经网络


Posted in Python onMay 24, 2018

本文实例为大家分享了TensorFlow实现卷积神经网络的具体代码,供大家参考,具体内容如下

代码(源代码都有详细的注释)和数据集可以在github下载:

# -*- coding: utf-8 -*-
'''卷积神经网络测试MNIST数据'''

#########导入MNIST数据########
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

# 创建默认InteractiveSession
sess = tf.InteractiveSession()


#########卷积网络会有很多的权重和偏置需要创建,先定义好初始化函数以便复用########
# 给权重制造一些随机噪声打破完全对称(比如截断的正态分布噪声,标准差设为0.1)
def weight_variable(shape):
 initial = tf.truncated_normal(shape, stddev=0.1)
 return tf.Variable(initial)
# 因为我们要使用ReLU,也给偏置增加一些小的正值(0.1)用来避免死亡节点(dead neurons)
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)


########卷积层、池化层接下来重复使用的,分别定义创建函数########
# tf.nn.conv2d是TensorFlow中的2维卷积函数
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
# 使用2*2的最大池化
def max_pool_2x2(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')


########正式设计卷积神经网络之前先定义placeholder########
# x是特征,y_是真实label。将图片数据从1D转为2D。使用tensor的变形函数tf.reshape
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
x_image = tf.reshape(x,[-1,28,28,1])


########设计卷积神经网络########
# 第一层卷积
# 卷积核尺寸为5*5,1个颜色通道,32个不同的卷积核
W_conv1 = weight_variable([5, 5, 1, 32])
# 用conv2d函数进行卷积操作,加上偏置
b_conv1 = bias_variable([32])
# 把x_image和权值向量进行卷积,加上偏置项,然后应用ReLU激活函数,
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
# 对卷积的输出结果进行池化操作
h_pool1 = max_pool_2x2(h_conv1)

# 第二层卷积(和第一层大致相同,卷积核为64,这一层卷积会提取64种特征)
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

# 全连接层。隐含节点数1024。使用ReLU激活函数
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

# 为了防止过拟合,在输出层之前加Dropout层
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# 输出层。添加一个softmax层,就像softmax regression一样。得到概率输出。
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)


########模型训练设置########
# 定义loss function为cross entropy,优化器使用Adam,并给予一个比较小的学习速率1e-4
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv),reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# 定义评测准确率的操作
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


########开始训练过程########
# 初始化所有参数
tf.global_variables_initializer().run()

# 训练(设置训练时Dropout的kepp_prob比率为0.5。mini-batch为50,进行2000次迭代训练,参与训练样本5万)
# 其中每进行100次训练,对准确率进行一次评测keep_prob设置为1,用以实时监测模型的性能
for i in range(1000):
 batch = mnist.train.next_batch(50)
 if i%100 == 0:
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0})
  print "-->step %d, training accuracy %.4f"%(i, train_accuracy)
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
# 全部训练完成之后,在最终测试集上进行全面测试,得到整体的分类准确率
print "卷积神经网络在MNIST数据集正确率: %g"%accuracy.eval(feed_dict={
  x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})

TensorFlow实现卷积神经网络

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Django的模版来配合字符串翻译工作
Jul 27 Python
python生成式的send()方法(详解)
May 08 Python
Python简单计算数组元素平均值的方法示例
Dec 26 Python
python实现图书馆研习室自动预约功能
Apr 27 Python
Python实现的tcp端口检测操作示例
Jul 24 Python
APIStar:一个专为Python3设计的API框架
Sep 26 Python
djang常用查询SQL语句的使用代码
Feb 15 Python
总结python中pass的作用
Feb 27 Python
python 控制Asterisk AMI接口外呼电话的例子
Aug 08 Python
Python利用pip安装tar.gz格式的离线资源包
Sep 14 Python
linux mint中搜狗输入法导致pycharm卡死的问题
Oct 28 Python
Python lxml库的简单介绍及基本使用讲解
Dec 22 Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
You might like
用JTrackBar实现的模拟苹果风格的滚动条
2007/08/06 Javascript
JS 显示当前日期与时间的代码
2010/03/24 Javascript
jquery blockUI 遮罩不能消失与不能提交的解决方法
2011/09/17 Javascript
表单的焦点顺序tabindex和对应enter键提交
2013/01/04 Javascript
js showModalDialog 弹出对话框的简单实例(子窗体)
2014/01/07 Javascript
深入理解JSON数据源格式
2014/01/10 Javascript
元素未显示设置width/height时IE中使用currentStyle获取为auto
2014/05/04 Javascript
javaScript使用EL表达式的几种方式
2014/05/27 Javascript
js replace(a,b)之替换字符串中所有指定字符的方法
2016/08/17 Javascript
js制作支付倒计时页面
2016/10/21 Javascript
React实现点击删除列表中对应项
2017/01/10 Javascript
jquery,js简单实现类似Angular.js双向绑定
2017/01/13 Javascript
webpack 2的react开发配置实例代码
2017/07/28 Javascript
使用Vue开发动态刷新Echarts组件的教程详解
2018/03/22 Javascript
React 全自动数据表格组件——BodeGrid的实现思路
2019/06/12 Javascript
JS轮播图的实现方法2
2020/08/25 Javascript
JavaScript如何实现防止重复的网络请求的示例
2021/01/28 Javascript
Python中实现对list做减法操作介绍
2015/01/09 Python
Python2.x与Python3.x的区别
2016/01/14 Python
Python 专题一 函数的基础知识
2017/03/16 Python
python中的字典操作及字典函数
2018/01/03 Python
Django项目实战之用户头像上传与访问的示例
2018/04/21 Python
python画图系列之个性化显示x轴区段文字的实例
2018/12/13 Python
PyCharm 设置SciView工具窗口的方法
2019/01/15 Python
简单了解python关系(比较)运算符
2019/07/08 Python
python中对_init_的理解及实例解析
2019/10/11 Python
Python基于os.environ从windows获取环境变量
2020/06/09 Python
CSS3打造磨砂玻璃背景效果
2016/09/28 HTML / CSS
html5 touch事件实现触屏页面上下滑动(二)
2016/03/10 HTML / CSS
马来西亚户外装备商店:PTT Outdoor
2019/07/13 全球购物
Java工程师面试集锦之Spring框架
2013/06/16 面试题
竞选班委演讲稿
2014/04/28 职场文书
学生检讨书怎么写
2015/05/07 职场文书
2019年个人工作总结范文
2019/03/25 职场文书
如何在Python中创建二叉树
2021/03/30 Python
详解nginx安装过程并代理下载服务器文件
2022/02/12 Servers