TensorFlow实现卷积神经网络CNN


Posted in Python onMarch 09, 2018

一、卷积神经网络CNN简介

卷积神经网络(ConvolutionalNeuralNetwork,CNN)最初是为解决图像识别等问题设计的,CNN现在的应用已经不限于图像和视频,也可用于时间序列信号,比如音频信号和文本数据等。CNN作为一个深度学习架构被提出的最初诉求是降低对图像数据预处理的要求,避免复杂的特征工程。在卷积神经网络中,第一个卷积层会直接接受图像像素级的输入,每一层卷积(滤波器)都会提取数据中最有效的特征,这种方法可以提取到图像中最基础的特征,而后再进行组合和抽象形成更高阶的特征,因此CNN在理论上具有对图像缩放、平移和旋转的不变性。

卷积神经网络CNN的要点就是局部连接(LocalConnection)、权值共享(WeightsSharing)和池化层(Pooling)中的降采样(Down-Sampling)。其中,局部连接和权值共享降低了参数量,使训练复杂度大大下降并减轻了过拟合。同时权值共享还赋予了卷积网络对平移的容忍性,池化层降采样则进一步降低了输出参数量并赋予模型对轻度形变的容忍性,提高了模型的泛化能力。可以把卷积层卷积操作理解为用少量参数在图像的多个位置上提取相似特征的过程。

更多请参见:深度学习之卷积神经网络CNN

二、TensorFlow代码实现

#!/usr/bin/env python2 
# -*- coding: utf-8 -*- 
""" 
Created on Thu Mar 9 22:01:46 2017 
 
@author: marsjhao 
""" 
 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 
sess = tf.InteractiveSession() 
 
def weight_variable(shape): 
 initial = tf.truncated_normal(shape, stddev=0.1) #标准差为0.1的正态分布 
 return tf.Variable(initial) 
 
def bias_variable(shape): 
 initial = tf.constant(0.1, shape=shape) #偏差初始化为0.1 
 return tf.Variable(initial) 
 
def conv2d(x, W): 
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
def max_pool_2x2(x): 
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], 
       strides=[1, 2, 2, 1], padding='SAME') 
 
x = tf.placeholder(tf.float32, [None, 784]) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
# -1代表先不考虑输入的图片例子多少这个维度,1是channel的数量 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
keep_prob = tf.placeholder(tf.float32) 
 
# 构建卷积层1 
W_conv1 = weight_variable([5, 5, 1, 32]) # 卷积核5*5,1个channel,32个卷积核,形成32个featuremap 
b_conv1 = bias_variable([32]) # 32个featuremap的偏置 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) # 用relu非线性处理 
h_pool1 = max_pool_2x2(h_conv1) # pooling池化 
 
# 构建卷积层2 
W_conv2 = weight_variable([5, 5, 32, 64]) # 注意这里channel值是32 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
# 构建全连接层1 
W_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool3 = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool3, W_fc1) + b_fc1) 
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 
 
# 构建全连接层2 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) 
 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), 
            reduction_indices=[1])) 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
correct_prediction = tf.equal(tf.arg_max(y_conv, 1), tf.arg_max(y_, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
tf.global_variables_initializer().run() 
 
for i in range(20001): 
 batch = mnist.train.next_batch(50) 
 if i % 100 == 0: 
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], 
             keep_prob: 1.0}) 
  print("step %d, training accuracy %g" %(i, train_accuracy)) 
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob:0.5}) 
print("test accuracy %g" %accuracy.eval(feed_dict={x: mnist.test.images, 
         y_: mnist.test.labels, keep_prob: 1.0}))

三、代码解读

该代码是用TensorFlow实现一个简单的卷积神经网络,在数据集MNIST上,预期可以实现99.2%左右的准确率。结构上使用两个卷积层和一个全连接层。

首先载入MNIST数据集,采用独热编码,并创建tf.InteractiveSession。然后为后续即将多次使用的部分代码创建函数,包括权重初始化weight_variable、偏置初始化bias_variable、卷积层conv2d、最大池化max_pool_2x2。其中权重初始化的时候要进行含有噪声的非对称初始化,打破完全对称。又由于我们要使用ReLU单元,也需要给偏置bias增加一些小的正值(0.1)用来避免死亡节点(dead neurons)。

构建卷积神经网络之前,先要定义输入的placeholder,特征x和真实标签y_,将1*784格式的特征x转换reshape为28*28的图片格式,又由于只有一个通道且不确定输入样本的数量,故最终尺寸为[-1, 28, 28, 1]。

接下来定义第一个卷积层,首先初始化weights和bias,然后使用conv2d进行卷积操作并加上偏置,随后使用ReLU激活函数进行非线性处理,最后使用最大池化函数对卷积的输出结果进行池化操作。

相同的步骤定义第二个卷积层,不同的地方是卷积核的数量为64,也就是说这一层的卷积会提取64种特征。经过两层不变尺寸的卷积和两次尺寸减半的池化,第二个卷积层后的输出尺寸为7*7*64。将其reshape为长度为7*7*64的1-D向量。经过ReLU后,为了减轻过拟合,使用一个Dropout层,在训练时随机丢弃部分节点的数据减轻过拟合,在预测的时候保留全部数据来追求最好的测试性能。

最后加一个Softmax层,得到最后的预测概率。随后的定义损失函数、优化器、评测准确率不再详细赘述。

训练过程首先进行初始化全部参数,训练时keep_prob比率设置为0.5,评测时设置为1。训练完成后,在最终的测试集上进行全面的测试,得到整体的分类准确率。

经过实验,这个CNN的模型可以得到99.2%的准确率,相比于MLP又有了较大幅度的提高。

四、其他解读补充

1. tf.nn.conv2d(x,W, strides=[1, 1, 1, 1], padding='SAME')

tf.nn.conv2d是TensorFlow的2维卷积函数,x和W都是4-D的tensors。x是输入input shape=[batch,in_height, in_width, in_channels],W是卷积的参数filter / kernel shape=[filter_height, filter_width, in_channels,out_channels]。strides参数是长度为4的1-D参数,代表了卷积核(滑动窗口)移动的步长,其中对于图片strides[0]和strides[3]必须是1,都是1表示不遗漏地划过图片的每一个点。padding参数中SAME代表给边界加上Padding让卷积的输出和输入保持相同的尺寸。

2. tf.nn.max_pool(x,ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

tf.nn.max_pool是TensorFlow中的最大池化函数,x是4-D的输入tensor shape=[batch, height, width, channels],ksize参数表示池化窗口的大小,取一个4维向量,一般是[1, height, width, 1],因为我们不想在batch和channels上做池化,所以这两个维度设为了1,strides与tf.nn.conv2d相同,strides=[1, 2, 2, 1]可以缩小图片尺寸。padding参数也参见tf.nn.conv2d。

Python 相关文章推荐
Python 异常处理实例详解
Mar 12 Python
Python使用爬虫爬取静态网页图片的方法详解
Jun 05 Python
Python代码块批量添加Tab缩进的方法
Jun 25 Python
Python实现朴素贝叶斯分类器的方法详解
Jul 04 Python
在pandas多重索引multiIndex中选定指定索引的行方法
Nov 16 Python
Python设计模式之抽象工厂模式原理与用法详解
Jan 15 Python
python pygame实现五子棋小游戏
Oct 26 Python
OpenCV 边缘检测
Jul 10 Python
python scipy卷积运算的实现方法
Sep 16 Python
基于Python实现大文件分割和命名脚本过程解析
Sep 29 Python
Django表单提交后实现获取相同name的不同value值
May 14 Python
基于python实现操作git过程代码解析
Jul 27 Python
新手常见6种的python报错及解决方法
Mar 09 #Python
Python 函数基础知识汇总
Mar 09 #Python
Python 使用with上下文实现计时功能
Mar 09 #Python
TensorFlow搭建神经网络最佳实践
Mar 09 #Python
TensorFlow实现Batch Normalization
Mar 08 #Python
用Django实现一个可运行的区块链应用
Mar 08 #Python
Python pyinotify日志监控系统处理日志的方法
Mar 08 #Python
You might like
PHP strip_tags()去除HTML、XML以及PHP的标签介绍
2014/02/18 PHP
javascript iframe内的函数调用实现方法
2009/07/19 Javascript
Javascript浅谈之this
2013/12/17 Javascript
JavaScript中实现map功能代码分享
2015/06/11 Javascript
在JavaScript的AngularJS库中进行单元测试的方法
2015/06/23 Javascript
jQuery判断多个input file 都不能为空的例子
2015/06/23 Javascript
JavaScript对HTML DOM使用EventListener进行操作
2015/10/21 Javascript
浅析JavaScript声明变量
2015/12/21 Javascript
angularjs创建弹出框实现拖动效果
2020/08/25 Javascript
node.js连接mongoDB数据库 快速搭建自己的web服务
2016/04/17 Javascript
带有定位当前位置的百度地图前端web api实例代码
2016/06/21 Javascript
JavaScript编程中实现对象封装特性的实例讲解
2016/06/24 Javascript
JavaScript制作简易计算器(不用eval)
2017/02/05 Javascript
微信小程序登录态控制深入分析
2017/04/12 Javascript
JavaScript数据类型和变量_动力节点Java学院整理
2017/06/26 Javascript
layui弹出层按钮提交iframe表单的方法
2018/08/20 Javascript
微信小程序云开发实现数据添加、查询和分页
2019/05/17 Javascript
浅析vue-cli3配置webpack-bundle-analyzer插件【推荐】
2019/10/23 Javascript
详解vue中在循环中使用@mouseenter 和 @mouseleave事件闪烁问题解决方法
2020/04/07 Javascript
jQuery实现简单三级联动效果
2020/09/05 jQuery
python操作mysql中文显示乱码的解决方法
2014/10/11 Python
Django学习笔记之Class-Based-View
2017/02/15 Python
Python中的id()函数指的什么
2017/10/17 Python
Python 错误和异常代码详解
2018/01/29 Python
python中的插值 scipy-interp的实现代码
2018/07/23 Python
对Python 3.5拼接列表的新语法详解
2018/11/08 Python
Python中使用filter过滤列表的一个小技巧分享
2020/05/02 Python
keras中的backend.clip用法
2020/05/22 Python
使用Keras实现简单线性回归模型操作
2020/06/12 Python
video结合canvas实现视频在线截图功能
2018/06/25 HTML / CSS
美国最大的香水连锁店官网:Perfumania
2016/08/15 全球购物
美国单身专业人士在线约会网站:EliteSingles
2019/03/19 全球购物
一年级语文教学反思
2014/02/13 职场文书
小学安全教育材料
2014/02/17 职场文书
小学语文教研活动总结
2014/07/01 职场文书
小学班主任事迹材料
2014/12/17 职场文书