TensorFlow实现卷积神经网络


Posted in Python onMay 24, 2018

本文实例为大家分享了TensorFlow实现卷积神经网络的具体代码,供大家参考,具体内容如下

代码(源代码都有详细的注释)和数据集可以在github下载:

# -*- coding: utf-8 -*-
'''卷积神经网络测试MNIST数据'''

#########导入MNIST数据########
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

# 创建默认InteractiveSession
sess = tf.InteractiveSession()


#########卷积网络会有很多的权重和偏置需要创建,先定义好初始化函数以便复用########
# 给权重制造一些随机噪声打破完全对称(比如截断的正态分布噪声,标准差设为0.1)
def weight_variable(shape):
 initial = tf.truncated_normal(shape, stddev=0.1)
 return tf.Variable(initial)
# 因为我们要使用ReLU,也给偏置增加一些小的正值(0.1)用来避免死亡节点(dead neurons)
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)


########卷积层、池化层接下来重复使用的,分别定义创建函数########
# tf.nn.conv2d是TensorFlow中的2维卷积函数
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
# 使用2*2的最大池化
def max_pool_2x2(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')


########正式设计卷积神经网络之前先定义placeholder########
# x是特征,y_是真实label。将图片数据从1D转为2D。使用tensor的变形函数tf.reshape
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
x_image = tf.reshape(x,[-1,28,28,1])


########设计卷积神经网络########
# 第一层卷积
# 卷积核尺寸为5*5,1个颜色通道,32个不同的卷积核
W_conv1 = weight_variable([5, 5, 1, 32])
# 用conv2d函数进行卷积操作,加上偏置
b_conv1 = bias_variable([32])
# 把x_image和权值向量进行卷积,加上偏置项,然后应用ReLU激活函数,
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
# 对卷积的输出结果进行池化操作
h_pool1 = max_pool_2x2(h_conv1)

# 第二层卷积(和第一层大致相同,卷积核为64,这一层卷积会提取64种特征)
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

# 全连接层。隐含节点数1024。使用ReLU激活函数
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

# 为了防止过拟合,在输出层之前加Dropout层
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# 输出层。添加一个softmax层,就像softmax regression一样。得到概率输出。
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)


########模型训练设置########
# 定义loss function为cross entropy,优化器使用Adam,并给予一个比较小的学习速率1e-4
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv),reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# 定义评测准确率的操作
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


########开始训练过程########
# 初始化所有参数
tf.global_variables_initializer().run()

# 训练(设置训练时Dropout的kepp_prob比率为0.5。mini-batch为50,进行2000次迭代训练,参与训练样本5万)
# 其中每进行100次训练,对准确率进行一次评测keep_prob设置为1,用以实时监测模型的性能
for i in range(1000):
 batch = mnist.train.next_batch(50)
 if i%100 == 0:
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0})
  print "-->step %d, training accuracy %.4f"%(i, train_accuracy)
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
# 全部训练完成之后,在最终测试集上进行全面测试,得到整体的分类准确率
print "卷积神经网络在MNIST数据集正确率: %g"%accuracy.eval(feed_dict={
  x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})

TensorFlow实现卷积神经网络

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python import用法以及与from...import的区别
May 28 Python
Django中处理出错页面的方法
Jul 15 Python
浅析python中的迭代与迭代对象
Oct 08 Python
用Python逐行分析文件方法
Jan 28 Python
详解Python字典的操作
Mar 04 Python
从0开始的Python学习014面向对象编程(推荐)
Apr 02 Python
python求最大值,不使用内置函数的实现方法
Jul 09 Python
Python实现打印实心和空心菱形
Nov 23 Python
用sleep间隔进行python反爬虫的实例讲解
Nov 30 Python
Python爬虫之Selenium设置元素等待的方法
Dec 04 Python
Python用access判断文件是否被占用的实例方法
Dec 17 Python
利用Python过滤相似文本的简单方法示例
Feb 03 Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
You might like
PHP Ajax实现页面无刷新发表评论
2007/01/02 PHP
php 编写安全的代码时容易犯的错误小结
2010/05/20 PHP
PHP类的反射用法实例
2014/11/03 PHP
php+jQuery+Ajax实现点赞效果的方法(附源码下载)
2020/07/21 PHP
LINUX下PHP程序实现WORD文件转化为PDF文件的方法
2016/05/13 PHP
js 创建快捷方式的代码(fso)
2010/11/19 Javascript
javascript中为某个元素指定事件的三种方式
2014/08/07 Javascript
jquery幻灯片插件bxslider样式改进实例
2014/10/15 Javascript
javascript操作数组详解
2014/12/17 Javascript
javascript实现日期按月份加减
2015/05/15 Javascript
JS学习之表格的排序简单实例
2016/05/16 Javascript
D3.js实现直方图的方法详解
2016/09/25 Javascript
基于JavaScript实现窗口拖动效果
2017/01/18 Javascript
jQuery判断邮箱格式对错实例代码讲解
2017/04/12 jQuery
vue2+el-menu实现路由跳转及当前项的设置方法实例
2017/11/07 Javascript
vue-cli 引入、配置axios的方法
2018/05/08 Javascript
JavaScript代码调试方法实例小结
2019/01/05 Javascript
element跨分页操作选择详解
2020/06/29 Javascript
在vue中实现给每个页面顶部设置title
2020/07/29 Javascript
vue实现tab栏点击高亮效果
2020/08/19 Javascript
Python图算法实例分析
2016/08/13 Python
Python实现的爬虫功能代码
2017/06/24 Python
python 3.6.4 安装配置方法图文教程
2018/09/18 Python
Pycharm创建项目时如何自动添加头部信息
2019/11/14 Python
python随机模块random使用方法详解
2020/02/14 Python
日本无添加化妆品:HABA
2016/08/18 全球购物
巴西24小时在线药房:Drogasil
2020/06/20 全球购物
经济系大学生求职信
2013/10/01 职场文书
五一服装活动方案
2014/01/11 职场文书
单位工程竣工验收方案
2014/03/16 职场文书
低碳环保倡议书
2014/04/14 职场文书
班级年度安全计划书
2014/05/01 职场文书
建筑学专业自荐书
2014/07/09 职场文书
MySQL中in和exists区别详解
2021/06/03 MySQL
MySQL中的隐藏列的具体查看
2021/09/04 MySQL
Java 死锁解决方案
2022/05/11 Java/Android