用tensorflow搭建CNN的方法


Posted in Python onMarch 05, 2018

CNN(Convolutional Neural Networks) 卷积神经网络简单讲就是把一个图片的数据传递给CNN,原涂层是由RGB组成,然后CNN把它的厚度加厚,长宽变小,每做一层都这样被拉长,最后形成一个分类器

用tensorflow搭建CNN的方法

在 CNN 中有几个重要的概念:

  1. stride
  2. padding
  3. pooling

stride,就是每跨多少步抽取信息。每一块抽取一部分信息,长宽就缩减,但是厚度增加。抽取的各个小块儿,再把它们合并起来,就变成一个压缩后的立方体。

padding,抽取的方式有两种,一种是抽取后的长和宽缩减,另一种是抽取后的长和宽和原来的一样。

pooling,就是当跨步比较大的时候,它会漏掉一些重要的信息,为了解决这样的问题,就加上一层叫pooling,事先把这些必要的信息存储起来,然后再变成压缩后的层

利用tensorflow搭建CNN,也就是卷积神经网络是一件很简单的事情,笔者按照官方教程中使用MNIST手写数字识别为例展开代码,整个程序也基本与官方例程一致,不过在比较容易迷惑的地方加入了注释,有一定的机器学习或者卷积神经网络制式的人都应该可以迅速领会到代码的含义。

#encoding=utf-8 
import tensorflow as tf  
import numpy as np  
from tensorflow.examples.tutorials.mnist import input_data  
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)  
def weight_variable(shape): 
  initial = tf.truncated_normal(shape,stddev=0.1) #截断正态分布,此函数原型为尺寸、均值、标准差 
  return tf.Variable(initial) 
def bias_variable(shape): 
  initial = tf.constant(0.1,shape=shape) 
  return tf.Variable(initial) 
def conv2d(x,W): 
  return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME') # strides第0位和第3为一定为1,剩下的是卷积的横向和纵向步长 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x,ksize = [1,2,2,1],strides=[1,2,2,1],padding='SAME')# 参数同上,ksize是池化块的大小 
 
x = tf.placeholder("float", shape=[None, 784]) 
y_ = tf.placeholder("float", shape=[None, 10]) 
 
# 图像转化为一个四维张量,第一个参数代表样本数量,-1表示不定,第二三参数代表图像尺寸,最后一个参数代表图像通道数 
x_image = tf.reshape(x,[-1,28,28,1]) 
 
# 第一层卷积加池化 
w_conv1 = weight_variable([5,5,1,32]) # 第一二参数值得卷积核尺寸大小,即patch,第三个参数是图像通道数,第四个参数是卷积核的数目,代表会出现多少个卷积特征 
b_conv1 = bias_variable([32]) 
 
h_conv1 = tf.nn.relu(conv2d(x_image,w_conv1)+b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
# 第二层卷积加池化  
w_conv2 = weight_variable([5,5,32,64]) # 多通道卷积,卷积出64个特征 
b_conv2 = bias_variable([64]) 
 
h_conv2 = tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
# 原图像尺寸28*28,第一轮图像缩小为14*14,共有32张,第二轮后图像缩小为7*7,共有64张 
 
w_fc1 = weight_variable([7*7*64,1024]) 
b_fc1 = bias_variable([1024]) 
 
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64]) # 展开,第一个参数为样本数量,-1未知 
f_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1) 
 
# dropout操作,减少过拟合 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_drop = tf.nn.dropout(f_fc1,keep_prob) 
 
w_fc2 = weight_variable([1024,10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop,w_fc2)+b_fc2) 
 
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv)) # 定义交叉熵为loss函数 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) # 调用优化器优化 
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) 
 
sess = tf.InteractiveSession() 
sess.run(tf.initialize_all_variables()) 
for i in range(2000): 
 batch = mnist.train.next_batch(50) 
 if i%100 == 0: 
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0}) 
  print "step %d, training accuracy %g"%(i, train_accuracy) 
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) 
 
print "test accuracy %g"%accuracy.eval(feed_dict={x: mnist.test.images[0:500], y_: mnist.test.labels[0:500], keep_prob: 1.0})

在程序中主要注意这么几点:

1、维度问题,由于我们tensorflow基于的是张量这样一个概念,张量其实就是维度扩展的矩阵,因此维度特别重要,而且维度也是很容易使人迷惑的地方。

2、卷积问题,卷积核不只是二维的,多通道卷积时卷积核就是三维的

3、最后进行检验的时候,如果一次性加载出所有的验证集,出现了内存爆掉的情况,由于是使用的是云端的服务器,可能内存小一些,如果内存够用可以直接全部加载上看结果

4、这个程序原始版本迭代次数设置了20000次,这个次数大约要训练数个小时(在不使用GPU的情况下),这个次数可以按照要求更改。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 创建子进程模块subprocess详解
Apr 08 Python
Python中多线程的创建及基本调用方法
Jul 08 Python
Python实现将一个大文件按段落分隔为多个小文件的简单操作方法
Apr 17 Python
Python 闭包的使用方法
Sep 07 Python
python队列通信:rabbitMQ的使用(实例讲解)
Dec 22 Python
Python数据分析库pandas基本操作方法
Apr 08 Python
Django forms组件的使用教程
Oct 08 Python
Python帮你识破双11的套路
Nov 11 Python
pytorch中的自定义数据处理详解
Jan 06 Python
PyQt5多线程防卡死和多窗口用法的实现
Sep 15 Python
PyTorch中clone()、detach()及相关扩展详解
Dec 09 Python
Python中递归以及递归遍历目录详解
Oct 24 Python
利用TensorFlow训练简单的二分类神经网络模型的方法
Mar 05 #Python
python使用Pycharm创建一个Django项目
Mar 05 #Python
python爬虫基本知识
Mar 05 #Python
用tensorflow构建线性回归模型的示例代码
Mar 05 #Python
详解python实现线程安全的单例模式
Mar 05 #Python
分析python动态规划的递归、非递归实现
Mar 04 #Python
python3.x上post发送json数据
Mar 04 #Python
You might like
PHP daddslashes 使用方法介绍
2012/10/26 PHP
php防止sql注入示例分析和几种常见攻击正则表达式
2014/01/12 PHP
详解WordPress中用于更新和获取用户选项数据的PHP函数
2016/03/08 PHP
PHP文件系统管理(实例讲解)
2017/09/19 PHP
php多进程中的阻塞与非阻塞操作实例分析
2020/03/04 PHP
在网页里看flash的trace数据的js类
2009/01/10 Javascript
js写的评论分页(还不错)
2013/12/23 Javascript
jQuery实现动画效果的简单实例
2014/01/27 Javascript
Nodejs学习笔记之Stream模块
2015/01/13 NodeJs
javascript实现Table间隔色以及选择高亮(和动态切换数据)的方法
2015/05/14 Javascript
jQuery+ajax+asp.net获取Json值的方法
2016/06/08 Javascript
详解Javascript中prototype属性(推荐)
2016/09/03 Javascript
简单实现jQuery级联菜单
2017/01/09 Javascript
安装vue-cli报错 -4058 的解决方法
2017/10/19 Javascript
9种改善AngularJS性能的方法
2017/11/28 Javascript
ES6/JavaScript使用技巧分享
2017/12/14 Javascript
Angular数据绑定机制原理
2018/04/17 Javascript
vue keep-alive请求数据的方法示例
2018/05/16 Javascript
vue+axios+element ui 实现全局loading加载示例
2018/09/11 Javascript
对python .txt文件读取及数据处理方法总结
2018/04/23 Python
python模块导入的细节详解
2018/12/10 Python
Django数据库类库MySQLdb使用详解
2019/04/28 Python
python文本数据处理学习笔记详解
2019/06/17 Python
Django+zTree构建组织架构树的方法
2019/08/21 Python
Python使用Chrome插件实现爬虫过程图解
2020/06/09 Python
如何向scrapy中的spider传递参数的几种方法
2020/11/18 Python
为什么要优先使用同步代码块而不是同步方法?
2013/01/30 面试题
幼儿园教师国培感言
2014/02/02 职场文书
应届毕业生应聘自荐信范文
2014/02/26 职场文书
厨师个人自我鉴定范文
2014/04/19 职场文书
大学生入党推荐书范文
2014/05/17 职场文书
企业理念标语
2014/06/09 职场文书
领导班子四风对照检查材料思想汇报
2014/09/26 职场文书
2014年学校团委工作总结
2014/12/20 职场文书
我在伊朗长大观后感
2015/06/16 职场文书
《穷人》教学反思
2016/02/19 职场文书