TensorFlow实现卷积神经网络


Posted in Python onMay 24, 2018

本文实例为大家分享了TensorFlow实现卷积神经网络的具体代码,供大家参考,具体内容如下

代码(源代码都有详细的注释)和数据集可以在github下载:

# -*- coding: utf-8 -*-
'''卷积神经网络测试MNIST数据'''

#########导入MNIST数据########
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

# 创建默认InteractiveSession
sess = tf.InteractiveSession()


#########卷积网络会有很多的权重和偏置需要创建,先定义好初始化函数以便复用########
# 给权重制造一些随机噪声打破完全对称(比如截断的正态分布噪声,标准差设为0.1)
def weight_variable(shape):
 initial = tf.truncated_normal(shape, stddev=0.1)
 return tf.Variable(initial)
# 因为我们要使用ReLU,也给偏置增加一些小的正值(0.1)用来避免死亡节点(dead neurons)
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)


########卷积层、池化层接下来重复使用的,分别定义创建函数########
# tf.nn.conv2d是TensorFlow中的2维卷积函数
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
# 使用2*2的最大池化
def max_pool_2x2(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')


########正式设计卷积神经网络之前先定义placeholder########
# x是特征,y_是真实label。将图片数据从1D转为2D。使用tensor的变形函数tf.reshape
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
x_image = tf.reshape(x,[-1,28,28,1])


########设计卷积神经网络########
# 第一层卷积
# 卷积核尺寸为5*5,1个颜色通道,32个不同的卷积核
W_conv1 = weight_variable([5, 5, 1, 32])
# 用conv2d函数进行卷积操作,加上偏置
b_conv1 = bias_variable([32])
# 把x_image和权值向量进行卷积,加上偏置项,然后应用ReLU激活函数,
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
# 对卷积的输出结果进行池化操作
h_pool1 = max_pool_2x2(h_conv1)

# 第二层卷积(和第一层大致相同,卷积核为64,这一层卷积会提取64种特征)
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

# 全连接层。隐含节点数1024。使用ReLU激活函数
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

# 为了防止过拟合,在输出层之前加Dropout层
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# 输出层。添加一个softmax层,就像softmax regression一样。得到概率输出。
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)


########模型训练设置########
# 定义loss function为cross entropy,优化器使用Adam,并给予一个比较小的学习速率1e-4
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv),reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# 定义评测准确率的操作
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


########开始训练过程########
# 初始化所有参数
tf.global_variables_initializer().run()

# 训练(设置训练时Dropout的kepp_prob比率为0.5。mini-batch为50,进行2000次迭代训练,参与训练样本5万)
# 其中每进行100次训练,对准确率进行一次评测keep_prob设置为1,用以实时监测模型的性能
for i in range(1000):
 batch = mnist.train.next_batch(50)
 if i%100 == 0:
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0})
  print "-->step %d, training accuracy %.4f"%(i, train_accuracy)
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
# 全部训练完成之后,在最终测试集上进行全面测试,得到整体的分类准确率
print "卷积神经网络在MNIST数据集正确率: %g"%accuracy.eval(feed_dict={
  x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})

TensorFlow实现卷积神经网络

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中__new__与__init__方法的区别详解
May 04 Python
Python 文件处理注意事项总结
Apr 10 Python
利用Python将时间或时间间隔转为ISO 8601格式方法示例
Sep 05 Python
Python3 max()函数基础用法
Feb 19 Python
浅谈Python小波分析库Pywavelets的一点使用心得
Jul 09 Python
python 使用pygame工具包实现贪吃蛇游戏(多彩版)
Oct 30 Python
使用opencv将视频帧转成图片输出
Dec 10 Python
python 3.8.3 安装配置图文教程
May 21 Python
django haystack实现全文检索的示例代码
Jun 24 Python
解决python 执行shell命令无法获取返回值的问题
Dec 05 Python
Python 可迭代对象 iterable的具体使用
Aug 07 Python
Python Pandas读取Excel日期数据的异常处理方法
Feb 28 Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
You might like
PHP中error_log()函数的使用方法
2015/01/20 PHP
php实现的读取CSV文件函数示例
2017/02/07 PHP
PHP合并两个或多个数组的方法
2019/01/20 PHP
php源码的安装方法和实例
2019/09/26 PHP
JS动态创建Table,Tr,Td并赋值的具体实现
2013/07/05 Javascript
javascript使用location.search的示例
2013/11/05 Javascript
在JS数组特定索引处指定位置插入元素
2014/07/27 Javascript
2014最热门的JavaScript代码高亮插件推荐
2014/11/25 Javascript
使用jquery动态加载js文件的方法
2014/12/24 Javascript
再次谈论Javascript中的this
2016/06/23 Javascript
vue的事件绑定与方法详解
2017/08/16 Javascript
JS实现字符串翻转的方法分析
2018/08/31 Javascript
JavaScript设计模式之命令模式实例分析
2019/01/16 Javascript
详解在React-Native中持久化redux数据
2019/05/22 Javascript
JavaScript的Proxy可以做哪些有意思的事儿
2019/06/15 Javascript
Bootstrap实现省市区三级联动(亲测可用)
2019/07/26 Javascript
JS实现简单省市二级联动
2019/11/27 Javascript
微信小程序indexOf的替换方法(推荐)
2020/01/14 Javascript
Vue 一键清空表单的实现方法
2020/02/07 Javascript
Python易忽视知识点小结
2015/05/25 Python
将pip源更换到国内镜像的详细步骤
2019/04/07 Python
pandas的连接函数concat()函数的具体使用方法
2019/07/09 Python
Python完全识别验证码自动登录实例详解
2019/11/24 Python
Python 下载Bing壁纸的示例
2020/09/29 Python
matplotlib bar()实现多组数据并列柱状图通用简便创建方法
2021/02/24 Python
CSS3实现文字波浪线效果示例代码
2016/11/20 HTML / CSS
canvas生成带二维码海报的踩坑记录
2019/09/11 HTML / CSS
TripAdvisor土耳其网站:全球知名旅行社区,真实旅客评论
2017/04/17 全球购物
波兰香水和化妆品购物网站:Notino.pl
2017/11/07 全球购物
Sneaker Studio波兰:购买运动鞋
2018/04/28 全球购物
涉外经济法专业毕业生推荐信
2013/11/24 职场文书
我的五年职业生涯规划
2014/01/23 职场文书
元旦红领巾广播稿
2014/02/19 职场文书
保护环境的建议书
2014/03/12 职场文书
生物工程专业求职信
2014/09/03 职场文书
2014党的群众路线教育实践活动总结报告
2014/10/31 职场文书