tensorflow学习笔记之mnist的卷积神经网络实例


Posted in Python onApril 15, 2018

mnist的卷积神经网络例子和上一篇博文中的神经网络例子大部分是相同的。但是CNN层数要多一些,网络模型需要自己来构建。

程序比较复杂,我就分成几个部分来叙述。

首先,下载并加载数据:

import tensorflow as tf 
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)   #下载并加载mnist数据
x = tf.placeholder(tf.float32, [None, 784])            #输入的数据占位符
y_actual = tf.placeholder(tf.float32, shape=[None, 10])      #输入的标签占位符

定义四个函数,分别用于初始化权值W,初始化偏置项b, 构建卷积层和构建池化层。

#定义一个函数,用于初始化所有的权值 W
def weight_variable(shape):
 initial = tf.truncated_normal(shape, stddev=0.1)
 return tf.Variable(initial)

#定义一个函数,用于初始化所有的偏置项 b
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)
 
#定义一个函数,用于构建卷积层
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

#定义一个函数,用于构建池化层
def max_pool(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')

接下来构建网络。整个网络由两个卷积层(包含激活层和池化层),一个全连接层,一个dropout层和一个softmax层组成。

#构建网络
x_image = tf.reshape(x, [-1,28,28,1])     #转换输入数据shape,以便于用于网络中
W_conv1 = weight_variable([5, 5, 1, 32])   
b_conv1 = bias_variable([32])    
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)   #第一个卷积层
h_pool1 = max_pool(h_conv1)                 #第一个池化层

W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)   #第二个卷积层
h_pool2 = max_pool(h_conv2)                  #第二个池化层

W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])       #reshape成向量
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)  #第一个全连接层

keep_prob = tf.placeholder("float") 
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)         #dropout层

W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_predict=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)  #softmax层

网络构建好后,就可以开始训练了。

cross_entropy = -tf.reduce_sum(y_actual*tf.log(y_predict))   #交叉熵
train_step = tf.train.GradientDescentOptimizer(1e-3).minimize(cross_entropy)  #梯度下降法
correct_prediction = tf.equal(tf.argmax(y_predict,1), tf.argmax(y_actual,1))  
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))         #精确度计算
sess=tf.InteractiveSession()             
sess.run(tf.initialize_all_variables())
for i in range(20000):
 batch = mnist.train.next_batch(50)
 if i%100 == 0:         #训练100次,验证一次
  train_acc = accuracy.eval(feed_dict={x:batch[0], y_actual: batch[1], keep_prob: 1.0})
  print 'step %d, training accuracy %g'%(i,train_acc)
  train_step.run(feed_dict={x: batch[0], y_actual: batch[1], keep_prob: 0.5})

test_acc=accuracy.eval(feed_dict={x: mnist.test.images, y_actual: mnist.test.labels, keep_prob: 1.0})
print "test accuracy %g"%test_acc

Tensorflow依赖于一个高效的C++后端来进行计算。与后端的这个连接叫做session。一般而言,使用TensorFlow程序的流程是先创建一个图,然后在session中启动它。

这里,我们使用更加方便的InteractiveSession类。通过它,你可以更加灵活地构建你的代码。它能让你在运行图的时候,插入一些计算图,这些计算图是由某些操作(operations)构成的。这对于工作在交互式环境中的人们来说非常便利,比如使用IPython。

训练20000次后,再进行测试,测试精度可以达到99%。

完整代码:

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 8 15:29:48 2016

@author: root
"""
import tensorflow as tf 
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)   #下载并加载mnist数据
x = tf.placeholder(tf.float32, [None, 784])            #输入的数据占位符
y_actual = tf.placeholder(tf.float32, shape=[None, 10])      #输入的标签占位符

#定义一个函数,用于初始化所有的权值 W
def weight_variable(shape):
 initial = tf.truncated_normal(shape, stddev=0.1)
 return tf.Variable(initial)

#定义一个函数,用于初始化所有的偏置项 b
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)
 
#定义一个函数,用于构建卷积层
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

#定义一个函数,用于构建池化层
def max_pool(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')

#构建网络
x_image = tf.reshape(x, [-1,28,28,1])     #转换输入数据shape,以便于用于网络中
W_conv1 = weight_variable([5, 5, 1, 32])   
b_conv1 = bias_variable([32])    
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)   #第一个卷积层
h_pool1 = max_pool(h_conv1)                 #第一个池化层

W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)   #第二个卷积层
h_pool2 = max_pool(h_conv2)                  #第二个池化层

W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])       #reshape成向量
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)  #第一个全连接层

keep_prob = tf.placeholder("float") 
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)         #dropout层

W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_predict=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)  #softmax层

cross_entropy = -tf.reduce_sum(y_actual*tf.log(y_predict))   #交叉熵
train_step = tf.train.GradientDescentOptimizer(1e-3).minimize(cross_entropy)  #梯度下降法
correct_prediction = tf.equal(tf.argmax(y_predict,1), tf.argmax(y_actual,1))  
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))         #精确度计算
sess=tf.InteractiveSession()             
sess.run(tf.initialize_all_variables())
for i in range(20000):
 batch = mnist.train.next_batch(50)
 if i%100 == 0:         #训练100次,验证一次
  train_acc = accuracy.eval(feed_dict={x:batch[0], y_actual: batch[1], keep_prob: 1.0})
  print('step',i,'training accuracy',train_acc)
  train_step.run(feed_dict={x: batch[0], y_actual: batch[1], keep_prob: 0.5})

test_acc=accuracy.eval(feed_dict={x: mnist.test.images, y_actual: mnist.test.labels, keep_prob: 1.0})
print("test accuracy",test_acc)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 使用get_argument获取url query参数
Apr 28 Python
Python模拟登陆实现代码
Jun 14 Python
Pycharm编辑器技巧之自动导入模块详解
Jul 18 Python
python os.path模块常用方法实例详解
Sep 16 Python
Python文件监听工具pyinotify与watchdog实例
Oct 15 Python
对Python中DataFrame选择某列值为XX的行实例详解
Jan 29 Python
python 生成器和迭代器的原理解析
Oct 12 Python
Python中使用threading.Event协调线程的运行详解
May 02 Python
python实现简单的五子棋游戏
Sep 01 Python
Python中qutip用法示例详解
Oct 02 Python
用Python监控你的朋友都在浏览哪些网站?
May 27 Python
python flask开发的简单基金查询工具
Jun 02 Python
tensorflow学习笔记之简单的神经网络训练和测试
Apr 15 #Python
Pytorch入门之mnist分类实例
Apr 14 #Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
You might like
PHP5 字符串处理函数大全
2010/03/23 PHP
浏览器预览PHP文件时顶部出现空白影响布局分析原因及解决办法
2013/01/11 PHP
thinkphp模板的包含与渲染实例分析
2014/11/26 PHP
简单通用的JS滑动门代码
2008/12/19 Javascript
jquery 操作表格实现代码(多种操作打包)
2011/03/20 Javascript
jquerymobile checkbox及时刷新才能获取其准确值
2012/04/14 Javascript
把字符串按照特定的字母顺序进行排序的js代码
2014/01/28 Javascript
jQuery弹出框代码封装DialogHelper
2015/01/30 Javascript
onmouseover事件和onmouseout事件全面理解
2016/08/15 Javascript
react-native ListView下拉刷新上拉加载实现代码
2017/08/03 Javascript
利用jsonp与代理服务器方案解决跨域问题
2017/09/14 Javascript
微信小程序 配置顶部导航条标题颜色的实现方法
2017/09/20 Javascript
基于JavaScript中字符串的match与replace方法(详解)
2017/12/04 Javascript
vue-cli2打包前和打包后的css前缀不一致的问题解决
2018/08/24 Javascript
仿ElementUI实现一个Form表单的实现代码
2019/04/23 Javascript
关于layui toolbar和template的结合使用方法
2019/09/19 Javascript
js实现简易计算器功能
2019/10/18 Javascript
微信浏览器左上角返回按钮监听的实现
2020/03/04 Javascript
详解Vue中Axios封装API接口的思路及方法
2020/10/10 Javascript
antd中table展开行默认展示,且不需要前边的加号操作
2020/11/02 Javascript
[02:38]DOTA2英雄基础教程 噬魂鬼
2014/01/03 DOTA
详解Python中expandtabs()方法的使用
2015/05/18 Python
python 将json数据提取转化为txt的方法
2018/10/26 Python
python中时间、日期、时间戳的转换的实现方法
2019/07/06 Python
关于Pytorch的MNIST数据集的预处理详解
2020/01/10 Python
Python GUI编程学习笔记之tkinter界面布局显示详解
2020/03/30 Python
python线性插值解析
2020/07/05 Python
Python中random模块常用方法的使用教程
2020/10/04 Python
解决Pycharm 运行后没有输出的问题
2021/02/05 Python
详解如何将 Canvas 绘制过程转为视频
2021/01/25 HTML / CSS
世界顶级足球门票网站:Live Football Tickets
2017/10/14 全球购物
波兰运动鞋网上商店:e-Sporting
2018/02/16 全球购物
Nayomi官网:沙特阿拉伯王国睡衣和内衣品牌
2020/12/19 全球购物
电子信息专业自荐书
2014/02/04 职场文书
十岁生日答谢词
2015/01/05 职场文书
mysq启动失败问题及场景分析
2021/07/15 MySQL