TensorFlow实现卷积神经网络


Posted in Python onMay 24, 2018

本文实例为大家分享了TensorFlow实现卷积神经网络的具体代码,供大家参考,具体内容如下

代码(源代码都有详细的注释)和数据集可以在github下载:

# -*- coding: utf-8 -*-
'''卷积神经网络测试MNIST数据'''

#########导入MNIST数据########
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

# 创建默认InteractiveSession
sess = tf.InteractiveSession()


#########卷积网络会有很多的权重和偏置需要创建,先定义好初始化函数以便复用########
# 给权重制造一些随机噪声打破完全对称(比如截断的正态分布噪声,标准差设为0.1)
def weight_variable(shape):
 initial = tf.truncated_normal(shape, stddev=0.1)
 return tf.Variable(initial)
# 因为我们要使用ReLU,也给偏置增加一些小的正值(0.1)用来避免死亡节点(dead neurons)
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)


########卷积层、池化层接下来重复使用的,分别定义创建函数########
# tf.nn.conv2d是TensorFlow中的2维卷积函数
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
# 使用2*2的最大池化
def max_pool_2x2(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')


########正式设计卷积神经网络之前先定义placeholder########
# x是特征,y_是真实label。将图片数据从1D转为2D。使用tensor的变形函数tf.reshape
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
x_image = tf.reshape(x,[-1,28,28,1])


########设计卷积神经网络########
# 第一层卷积
# 卷积核尺寸为5*5,1个颜色通道,32个不同的卷积核
W_conv1 = weight_variable([5, 5, 1, 32])
# 用conv2d函数进行卷积操作,加上偏置
b_conv1 = bias_variable([32])
# 把x_image和权值向量进行卷积,加上偏置项,然后应用ReLU激活函数,
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
# 对卷积的输出结果进行池化操作
h_pool1 = max_pool_2x2(h_conv1)

# 第二层卷积(和第一层大致相同,卷积核为64,这一层卷积会提取64种特征)
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

# 全连接层。隐含节点数1024。使用ReLU激活函数
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

# 为了防止过拟合,在输出层之前加Dropout层
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# 输出层。添加一个softmax层,就像softmax regression一样。得到概率输出。
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)


########模型训练设置########
# 定义loss function为cross entropy,优化器使用Adam,并给予一个比较小的学习速率1e-4
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv),reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# 定义评测准确率的操作
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


########开始训练过程########
# 初始化所有参数
tf.global_variables_initializer().run()

# 训练(设置训练时Dropout的kepp_prob比率为0.5。mini-batch为50,进行2000次迭代训练,参与训练样本5万)
# 其中每进行100次训练,对准确率进行一次评测keep_prob设置为1,用以实时监测模型的性能
for i in range(1000):
 batch = mnist.train.next_batch(50)
 if i%100 == 0:
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0})
  print "-->step %d, training accuracy %.4f"%(i, train_accuracy)
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
# 全部训练完成之后,在最终测试集上进行全面测试,得到整体的分类准确率
print "卷积神经网络在MNIST数据集正确率: %g"%accuracy.eval(feed_dict={
  x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})

TensorFlow实现卷积神经网络

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现模拟按键,自动翻页看u17漫画
Mar 17 Python
举例讲解Python设计模式编程中的访问者与观察者模式
Jan 26 Python
Python中使用platform模块获取系统信息的用法教程
Jul 08 Python
Python中__init__.py文件的作用详解
Sep 18 Python
Python实现利用最大公约数求三个正整数的最小公倍数示例
Sep 30 Python
Python pycharm 同时加载多个项目的方法
Jan 17 Python
基于python的itchat库实现微信聊天机器人(推荐)
Oct 29 Python
python对Excel按条件进行内容补充(推荐)
Nov 24 Python
Python urllib3软件包的使用说明
Nov 18 Python
Python 实现二叉查找树的示例代码
Dec 21 Python
anaconda升级sklearn版本的实现方法
Feb 22 Python
Python&Matlab实现灰狼优化算法的示例代码
Mar 21 Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
You might like
PHP中获取文件扩展名的N种方法小结
2012/02/27 PHP
header跳转和include包含问题详解
2012/09/08 PHP
php include和require的区别深入解析
2013/06/17 PHP
PHP生成网站桌面快捷方式代码分享
2014/10/11 PHP
php curl抓取网页的介绍和推广及使用CURL抓取淘宝页面集成方法
2015/11/30 PHP
Yii2数据库操作常用方法小结
2017/05/04 PHP
js动态创建、删除表格示例代码
2013/08/07 Javascript
用JavaScript计算在UTF-8下存储字符串占用字节数
2013/08/08 Javascript
extjs每个组件要设置唯一的ID否则会出错
2014/06/15 Javascript
JS+CSS实现类似QQ好友及黑名单效果的树型菜单
2015/09/22 Javascript
Bootstrap滚动监听(Scrollspy)插件详解
2016/04/26 Javascript
js Date()日期函数浏览器兼容问题解决方法
2017/09/12 Javascript
详解如何使用webpack在vue项目中写jsx语法
2017/11/08 Javascript
关于vue单文件中引用路径的处理方法
2018/01/08 Javascript
详解如何在vue项目中使用eslint+prettier格式化代码
2018/11/10 Javascript
[jQuery] 事件和动画详解
2019/03/05 jQuery
js+canvas实现纸牌游戏
2020/03/16 Javascript
理解Proxy及使用Proxy实现vue数据双向绑定操作
2020/07/18 Javascript
JavaScript位置参数实现原理及过程解析
2020/09/14 Javascript
python selenium 对浏览器标签页进行关闭和切换的方法
2018/05/21 Python
Python使用Selenium模块实现模拟浏览器抓取淘宝商品美食信息功能示例
2018/07/18 Python
djang常用查询SQL语句的使用代码
2019/02/15 Python
简单的Python调度器Schedule详解
2019/08/30 Python
python 实现线程之间的通信示例
2020/02/14 Python
CSS类名支持中文命名的示例
2014/04/04 HTML / CSS
Brora官网:英国领先的羊绒服装品牌
2019/08/28 全球购物
Web Service面试题:如何搭建Axis2的开发环境
2012/06/20 面试题
行政经理岗位职责
2013/11/09 职场文书
毕业生求职自荐信怎么写
2014/01/08 职场文书
暑期社会实践感言
2014/02/25 职场文书
机房搬迁方案
2014/05/01 职场文书
学习雷锋标语
2014/06/25 职场文书
观看《周恩来的四个昼夜》思想汇报
2014/09/12 职场文书
2014年超市员工工作总结
2014/11/18 职场文书
拉贝日记观后感
2015/06/05 职场文书
python之np.argmax()及对axis=0或者1的理解
2021/06/02 Python