TensorFlow实现卷积神经网络


Posted in Python onMay 24, 2018

本文实例为大家分享了TensorFlow实现卷积神经网络的具体代码,供大家参考,具体内容如下

代码(源代码都有详细的注释)和数据集可以在github下载:

# -*- coding: utf-8 -*-
'''卷积神经网络测试MNIST数据'''

#########导入MNIST数据########
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

# 创建默认InteractiveSession
sess = tf.InteractiveSession()


#########卷积网络会有很多的权重和偏置需要创建,先定义好初始化函数以便复用########
# 给权重制造一些随机噪声打破完全对称(比如截断的正态分布噪声,标准差设为0.1)
def weight_variable(shape):
 initial = tf.truncated_normal(shape, stddev=0.1)
 return tf.Variable(initial)
# 因为我们要使用ReLU,也给偏置增加一些小的正值(0.1)用来避免死亡节点(dead neurons)
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)


########卷积层、池化层接下来重复使用的,分别定义创建函数########
# tf.nn.conv2d是TensorFlow中的2维卷积函数
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
# 使用2*2的最大池化
def max_pool_2x2(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')


########正式设计卷积神经网络之前先定义placeholder########
# x是特征,y_是真实label。将图片数据从1D转为2D。使用tensor的变形函数tf.reshape
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
x_image = tf.reshape(x,[-1,28,28,1])


########设计卷积神经网络########
# 第一层卷积
# 卷积核尺寸为5*5,1个颜色通道,32个不同的卷积核
W_conv1 = weight_variable([5, 5, 1, 32])
# 用conv2d函数进行卷积操作,加上偏置
b_conv1 = bias_variable([32])
# 把x_image和权值向量进行卷积,加上偏置项,然后应用ReLU激活函数,
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
# 对卷积的输出结果进行池化操作
h_pool1 = max_pool_2x2(h_conv1)

# 第二层卷积(和第一层大致相同,卷积核为64,这一层卷积会提取64种特征)
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

# 全连接层。隐含节点数1024。使用ReLU激活函数
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

# 为了防止过拟合,在输出层之前加Dropout层
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# 输出层。添加一个softmax层,就像softmax regression一样。得到概率输出。
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)


########模型训练设置########
# 定义loss function为cross entropy,优化器使用Adam,并给予一个比较小的学习速率1e-4
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv),reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# 定义评测准确率的操作
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


########开始训练过程########
# 初始化所有参数
tf.global_variables_initializer().run()

# 训练(设置训练时Dropout的kepp_prob比率为0.5。mini-batch为50,进行2000次迭代训练,参与训练样本5万)
# 其中每进行100次训练,对准确率进行一次评测keep_prob设置为1,用以实时监测模型的性能
for i in range(1000):
 batch = mnist.train.next_batch(50)
 if i%100 == 0:
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0})
  print "-->step %d, training accuracy %.4f"%(i, train_accuracy)
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
# 全部训练完成之后,在最终测试集上进行全面测试,得到整体的分类准确率
print "卷积神经网络在MNIST数据集正确率: %g"%accuracy.eval(feed_dict={
  x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})

TensorFlow实现卷积神经网络

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入解析Python中的变量和赋值运算符
Oct 12 Python
Python中的异常处理相关语句基础学习笔记
Jul 11 Python
Python编写一个闹钟功能
Jul 11 Python
python利用正则表达式排除集合中字符的功能示例
Oct 10 Python
使用Python机器学习降低静态日志噪声
Sep 29 Python
pandas修改DataFrame列名的实现方法
Feb 22 Python
Python计算时间间隔(精确到微妙)的代码实例
Feb 26 Python
Python matplotlib以日期为x轴作图代码实例
Nov 22 Python
Python中six模块基础用法
Dec 08 Python
python 字符串的驻留机制及优缺点
Jun 19 Python
如何基于Python实现word文档重新排版
Sep 29 Python
详细总结Python常见的安全问题
May 21 Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
You might like
239军机修复记
2021/03/02 无线电
解析使用substr截取UTF-8中文字符串出现乱码的问题
2013/06/20 PHP
javascript 全等号运算符使用说明
2010/05/31 Javascript
深入理解JavaScript系列(9) 根本没有“JSON对象”这回事!
2012/01/15 Javascript
jquery 操作日期、星期、元素的追加的实现代码
2012/02/07 Javascript
Js制作简单弹出层DIV在页面居中 中间显示遮罩的具体方法
2013/08/08 Javascript
Javascript浮点数乘积运算出现多位小数的解决方法
2014/02/17 Javascript
jquery实现的省市区三级联动
2015/04/02 Javascript
深入浅析JS Function()构造函数
2016/08/22 Javascript
js实现导航栏中英文切换效果
2017/01/16 Javascript
详解Angualr 组件间通信
2017/01/21 Javascript
nodejs中模块定义实例详解
2017/03/18 NodeJs
EasyUI中的dataGrid的行内编辑
2017/06/22 Javascript
微信小程序request请求后台接口php的实例详解
2017/09/20 Javascript
微信小程序返回多级页面的实现方法
2017/10/27 Javascript
浅谈React中的元素、组件、实例和节点
2018/02/27 Javascript
了解ESlint和其相关操作小结
2018/05/21 Javascript
vue  自定义组件实现通讯录功能
2018/09/30 Javascript
bootstrap下拉分页样式 带跳转页码
2018/12/29 Javascript
详解Vue基于vue-quill-editor富文本编辑器使用心得
2019/01/03 Javascript
vue中多个倒计时实现代码实例
2019/03/27 Javascript
解决vue项目中遇到 Cannot find module ‘chalk‘ 报错的问题
2020/11/05 Javascript
JS实现简易日历效果
2021/01/25 Javascript
举例详解Python中threading模块的几个常用方法
2015/06/18 Python
python实现kMeans算法
2017/12/21 Python
Python之列表实现栈的工作功能
2019/01/28 Python
python各层级目录下import方法代码实例
2020/01/20 Python
销售员求职个人的自我评价
2014/02/19 职场文书
小学生开学感言
2014/02/28 职场文书
民政局个人整改措施
2014/09/24 职场文书
2014年单位法制宣传日活动总结
2014/11/01 职场文书
一年级班主任工作总结2014
2014/11/08 职场文书
人事行政助理岗位职责
2015/04/11 职场文书
2016新年晚会开场白
2015/12/03 职场文书
2016大学生毕业实习心得体会
2016/01/23 职场文书
用Python提取PDF表格的方法
2021/04/11 Python