用tensorflow搭建CNN的方法


Posted in Python onMarch 05, 2018

CNN(Convolutional Neural Networks) 卷积神经网络简单讲就是把一个图片的数据传递给CNN,原涂层是由RGB组成,然后CNN把它的厚度加厚,长宽变小,每做一层都这样被拉长,最后形成一个分类器

用tensorflow搭建CNN的方法

在 CNN 中有几个重要的概念:

  1. stride
  2. padding
  3. pooling

stride,就是每跨多少步抽取信息。每一块抽取一部分信息,长宽就缩减,但是厚度增加。抽取的各个小块儿,再把它们合并起来,就变成一个压缩后的立方体。

padding,抽取的方式有两种,一种是抽取后的长和宽缩减,另一种是抽取后的长和宽和原来的一样。

pooling,就是当跨步比较大的时候,它会漏掉一些重要的信息,为了解决这样的问题,就加上一层叫pooling,事先把这些必要的信息存储起来,然后再变成压缩后的层

利用tensorflow搭建CNN,也就是卷积神经网络是一件很简单的事情,笔者按照官方教程中使用MNIST手写数字识别为例展开代码,整个程序也基本与官方例程一致,不过在比较容易迷惑的地方加入了注释,有一定的机器学习或者卷积神经网络制式的人都应该可以迅速领会到代码的含义。

#encoding=utf-8 
import tensorflow as tf  
import numpy as np  
from tensorflow.examples.tutorials.mnist import input_data  
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)  
def weight_variable(shape): 
  initial = tf.truncated_normal(shape,stddev=0.1) #截断正态分布,此函数原型为尺寸、均值、标准差 
  return tf.Variable(initial) 
def bias_variable(shape): 
  initial = tf.constant(0.1,shape=shape) 
  return tf.Variable(initial) 
def conv2d(x,W): 
  return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME') # strides第0位和第3为一定为1,剩下的是卷积的横向和纵向步长 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x,ksize = [1,2,2,1],strides=[1,2,2,1],padding='SAME')# 参数同上,ksize是池化块的大小 
 
x = tf.placeholder("float", shape=[None, 784]) 
y_ = tf.placeholder("float", shape=[None, 10]) 
 
# 图像转化为一个四维张量,第一个参数代表样本数量,-1表示不定,第二三参数代表图像尺寸,最后一个参数代表图像通道数 
x_image = tf.reshape(x,[-1,28,28,1]) 
 
# 第一层卷积加池化 
w_conv1 = weight_variable([5,5,1,32]) # 第一二参数值得卷积核尺寸大小,即patch,第三个参数是图像通道数,第四个参数是卷积核的数目,代表会出现多少个卷积特征 
b_conv1 = bias_variable([32]) 
 
h_conv1 = tf.nn.relu(conv2d(x_image,w_conv1)+b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
# 第二层卷积加池化  
w_conv2 = weight_variable([5,5,32,64]) # 多通道卷积,卷积出64个特征 
b_conv2 = bias_variable([64]) 
 
h_conv2 = tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
# 原图像尺寸28*28,第一轮图像缩小为14*14,共有32张,第二轮后图像缩小为7*7,共有64张 
 
w_fc1 = weight_variable([7*7*64,1024]) 
b_fc1 = bias_variable([1024]) 
 
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64]) # 展开,第一个参数为样本数量,-1未知 
f_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1) 
 
# dropout操作,减少过拟合 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_drop = tf.nn.dropout(f_fc1,keep_prob) 
 
w_fc2 = weight_variable([1024,10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop,w_fc2)+b_fc2) 
 
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv)) # 定义交叉熵为loss函数 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) # 调用优化器优化 
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) 
 
sess = tf.InteractiveSession() 
sess.run(tf.initialize_all_variables()) 
for i in range(2000): 
 batch = mnist.train.next_batch(50) 
 if i%100 == 0: 
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0}) 
  print "step %d, training accuracy %g"%(i, train_accuracy) 
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) 
 
print "test accuracy %g"%accuracy.eval(feed_dict={x: mnist.test.images[0:500], y_: mnist.test.labels[0:500], keep_prob: 1.0})

在程序中主要注意这么几点:

1、维度问题,由于我们tensorflow基于的是张量这样一个概念,张量其实就是维度扩展的矩阵,因此维度特别重要,而且维度也是很容易使人迷惑的地方。

2、卷积问题,卷积核不只是二维的,多通道卷积时卷积核就是三维的

3、最后进行检验的时候,如果一次性加载出所有的验证集,出现了内存爆掉的情况,由于是使用的是云端的服务器,可能内存小一些,如果内存够用可以直接全部加载上看结果

4、这个程序原始版本迭代次数设置了20000次,这个次数大约要训练数个小时(在不使用GPU的情况下),这个次数可以按照要求更改。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python实现QQ游戏大家来找茬辅助工具
Sep 14 Python
python运行时间的几种方法
Jun 17 Python
Python信息抽取之乱码解决办法
Jun 29 Python
python绘制简单折线图代码示例
Dec 19 Python
python实现数据预处理之填充缺失值的示例
Dec 22 Python
详解重置Django migration的常见方式
Feb 15 Python
Python3.4解释器用法简单示例
Mar 22 Python
python 使用shutil复制图片的例子
Dec 13 Python
解决python 虚拟环境删除包无法加载的问题
Jul 13 Python
基于Python pyecharts实现多种图例代码解析
Aug 10 Python
通过代码实例解析Pytest运行流程
Aug 20 Python
Python连续赋值需要注意的一些问题
Jun 03 Python
利用TensorFlow训练简单的二分类神经网络模型的方法
Mar 05 #Python
python使用Pycharm创建一个Django项目
Mar 05 #Python
python爬虫基本知识
Mar 05 #Python
用tensorflow构建线性回归模型的示例代码
Mar 05 #Python
详解python实现线程安全的单例模式
Mar 05 #Python
分析python动态规划的递归、非递归实现
Mar 04 #Python
python3.x上post发送json数据
Mar 04 #Python
You might like
ThinkPHP水印功能实现修复PNG透明水印并增加JPEG图片质量可调整
2014/11/05 PHP
php实现留言板功能
2017/03/05 PHP
php 使用curl模拟ip和来源进行访问的实现方法
2017/05/02 PHP
JSON 入门指南 想了解json的朋友可以看下
2009/08/26 Javascript
Dom 结点创建 基础知识
2011/10/01 Javascript
页面调用单个swf文件,嵌套出多个方法。
2011/11/21 Javascript
基于jquery的textarea发布框限制文字字数输入(添加中文识别)
2012/02/16 Javascript
两种常用的javascript数组去重方法思路及代码
2013/03/26 Javascript
MultiSelect左右选择控件的设计与实现介绍
2013/06/08 Javascript
jquery ui dialog实现弹窗特效的思路及代码
2013/08/03 Javascript
点击进行复制的JS代码实例
2013/08/23 Javascript
js类式继承的具体实现方法
2013/12/31 Javascript
网页中表单按回车就自动提交的问题的解决方案
2014/11/03 Javascript
js实现仿百度风云榜可重复多次调用的TAB切换选项卡效果
2015/08/31 Javascript
Bootstrap栅格系统简单实现代码
2017/03/06 Javascript
基于Bootstrap框架实现图片切换
2017/03/10 Javascript
nodejs利用ajax实现网页无刷新上传图片实例代码
2017/06/06 NodeJs
vue.js学习之UI组件开发教程
2017/07/03 Javascript
Node.js 使用流实现读写同步边读边写功能
2017/09/11 Javascript
浅谈Vue数据响应思路之数组
2018/11/06 Javascript
JavaScript查看代码运行效率console.time()与console.timeEnd()用法
2019/01/18 Javascript
一篇文章弄懂javascript中的执行栈与执行上下文
2019/08/09 Javascript
世界上最短的数字判断js代码
2019/09/09 Javascript
js之切换全屏和退出全屏实现代码实例
2019/09/09 Javascript
JS实现简单移动端鼠标拖拽
2020/07/23 Javascript
[06:25]第二届DOTA2亚洲邀请赛主赛事第二天比赛集锦.mp4
2017/04/03 DOTA
解读! Python在人工智能中的作用
2017/11/14 Python
python flask解析json数据不完整的解决方法
2019/05/26 Python
ZABBIX3.2使用python脚本实现监控报表的方法
2019/07/02 Python
python networkx 根据图的权重画图实现
2019/07/10 Python
python中有关时间日期格式转换问题
2019/12/25 Python
Python Pandas list列表数据列拆分成多行的方法实现
2020/12/14 Python
python自动生成sql语句的脚本
2021/02/24 Python
2015年高校教师个人工作总结
2015/05/25 职场文书
2015年为民办实事工作总结
2015/05/26 职场文书
Django使用channels + websocket打造在线聊天室
2021/05/20 Python