tensorflow实现简单的卷积神经网络


Posted in Python onMay 24, 2018

本文实例为大家分享了Android九宫格图片展示的具体代码,供大家参考,具体内容如下

一.知识点总结

1.  卷积神经网络出现的初衷是降低对图像的预处理,避免建立复杂的特征工程。因为卷积神经网络在训练的过程中,自己会提取特征。

2.   灵感来自于猫的视觉皮层研究,每一个视觉神经元只会处理一小块区域的视觉图像,即感知野。放到卷积神经网络里就是每一个隐含节点只与设定范围内的像素点相连(设定范围就是卷积核的尺寸),而全连接层是每个像素点与每个隐含节点相连。这种感知野也称之为局部感知。

例如,一张1000*1000的图片,如果隐含层有100*100个隐含节点全连接,则需要1000*1000*100*100+100*100个参数,而如果有10*10的范围局部感知,用同样多的隐含节点,只需要10*10*100*100+100*100个参数。

3.  把卷积的过程称作卷积滤波,除了上面的局部感知,卷积滤波还有一个化简操作——权值共享。即一个卷积滤波中的所有隐含节点与感知图像连接的权值是一样的,这样,上述例子的参数减少为10*10+100*100个了。W的数量等于感知范围的尺寸。

4.  为了抗变形和减小复杂度,卷积层同时还要做激活和池化。激活函数前一章已经弄明白了,池化,相当于降采样,将n*n的像素区域采样为m*m区域,m通常小于n。通常选择最大池化,即选择区域内的最大像素点。 

5.  总结来讲,卷积有三个要点:局部连接、权值共享、池化降采样。一个卷积过程包含三个步骤:卷积滤波、激活、池化。 

6.  卷积滤波中的卷积范围可以用一个词来代替——卷积核,卷积核等同于卷积滤波中的一个隐含节点感知范围。由于权值共享,相当于一个卷积核对整个图像做多次小范围滤波,每滤一次波生成一个小的特征图像,多次滤波后将所有小特征图像组合起来,生成了对整个图像的feature map。通常,一个卷积滤波过程有多个卷积核卷积,生成多张feature map。

所有的feature map都会被池化,然后输入下一层。 

7.  需要训练的权值(参数)的数量只和卷积核尺寸有关,隐含节点(即卷积核要卷积的次数)只和卷积的卷积步长、图像尺寸有关。

个人理解,一个卷积核对整个图像卷积的过程,就像是一个棋子,在整个棋盘上按照步长跳动,每跳动一次,对感知范围内的像素点做一次连接计算。 

8.  CNN在结构上和图像的结构更为接近,都是2D的,因此,早期用在图像上效果很好,但是最近,CNN用于NLP也很热门。

二.程序解析

# coding: utf-8 
 
# In[1]: 
 
from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNSIT_data/", one_hot=True) 
sess = tf.InteractiveSession() 
 
 
# In[2]: 
#由于W和b在各层中均要用到,先定义乘函数。 
#tf.truncated_normal:截断正态分布,即限制范围的正态分布 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
 
 
# In[7]: 
#bias初始化值0.1. 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
 
# In[12]: 
#tf.nn.conv2d:二维的卷积 
#conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None,data_format=None, name=None) 
#filter:A 4-D tensor of shape 
#   `[filter_height, filter_width, in_channels, out_channels]` 
#strides:步长,都是1表示所有点都不会被遗漏。1-D 4值,表示每歌dim的移动步长。 
# padding:边界的处理方式,“SAME"、"VALID”可选 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
#tf.nn.max_pool:最大值池化函数,即求2*2区域的最大值,保留最显著的特征。 
#max_pool(value, ksize, strides, padding, data_format="NHWC", name=None) 
#ksize:池化窗口的尺寸 
#strides:[1,2,2,1]表示横竖方向步长为2 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides = [1, 2, 2, 1], padding='SAME') 
 
 
x = tf.placeholder(tf.float32, [None, 784]) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
#tf.reshape:tensor的变形函数。 
#-1:样本数量不固定 
#28,28:新形状的shape 
#1:颜色通道数 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
 
 
#卷积层包含三部分:卷积计算、激活、池化 
#[5,5,1,32]表示卷积核的尺寸为5×5, 颜色通道为1, 有32个卷积核 
W_conv1 = weight_variable([5, 5, 1, 32]) 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
 
W_conv2 = weight_variable([5, 5, 32, 64]) 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
 
#经过2次2×2的池化后,图像的尺寸变为7×7,第二个卷积层有64个卷积核,生成64类特征,因此,卷积最后输出为7×7×64. 
#tensor进入全连接层之前,先将64张二维图像变形为1维图像,便于计算。 
W_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
 
#对全连接层做dropot 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob) 
 
 
#又一个全连接后foftmax分类 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_dropout, W_fc2) + b_fc2) 
 
 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv), reduction_indices=[1])) 
#AdamOptimizer:Adam优化函数 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
 
 
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
 
#训练,并且每100个batch计算一次精度 
tf.global_variables_initializer().run() 
for i in range(20000): 
  batch = mnist.train.next_batch(50) 
  if i%100 == 0: 
    train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], keep_prob:1.0}) 
    print("step %d, training accuracy %g" %(i, train_accuracy)) 
  train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 
 
 
#在测试集上测试 
print("test accuracy %g"%accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

补充一下目前三个网络在mnist上的精度分别为:

无隐含层的softmax:91.5%

加入一个全连接隐含层的感知机:98.1%

此cnn:99.07%

和作者的训练结果有细微的差异,可能设备不同吧。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 如何访问外围作用域中的变量
Sep 11 Python
Python爬取三国演义的实现方法
Sep 12 Python
Python编程之Re模块下的函数介绍
Oct 28 Python
Python pandas常用函数详解
Feb 07 Python
python实现微信机器人: 登录微信、消息接收、自动回复功能
Apr 29 Python
Python3+OpenCV2实现图像的几何变换(平移、镜像、缩放、旋转、仿射)
May 13 Python
学习python分支结构
May 17 Python
python 返回一个列表中第二大的数方法
Jul 09 Python
Python将list元素转存为CSV文件的实现
Nov 16 Python
最新pycharm安装教程
Nov 18 Python
Python爬虫之Selenium警告框(弹窗)处理
Dec 04 Python
pytorch 如何把图像数据集进行划分成train,test和val
May 31 Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
You might like
ThinkPHP路由详解
2015/07/27 PHP
CI框架实现优化文件上传及多文件上传的方法
2017/01/04 PHP
php实现QQ小程序发送模板消息功能
2019/09/18 PHP
根据json字符串生成Html的一种方式
2013/01/09 Javascript
自己使用js/jquery写的一个定制对话框控件
2014/05/02 Javascript
告诉你什么是javascript的回调函数
2014/09/04 Javascript
jQuery实现DIV层收缩展开的方法
2015/02/27 Javascript
jQuery实现行文字链接提示效果的方法
2015/03/10 Javascript
AngularJS验证信息框架的封装插件用法【w5cValidator扩展插件】
2016/11/03 Javascript
浅谈JavaScript的计时器对象
2016/12/26 Javascript
vue一个页面实现音乐播放器的示例
2018/02/06 Javascript
为jquery的ajax请求添加超时timeout时间的操作方法
2018/09/04 jQuery
浅谈javascript中的prototype和__proto__的理解
2019/04/07 Javascript
使用layer弹窗提交表单时判断表单是否输入为空的例子
2019/09/26 Javascript
JavaScript实现轮播图效果代码实例
2019/09/28 Javascript
JS动态显示倒计时效果
2019/12/12 Javascript
[40:05]LGD vs Winstrike 2018国际邀请赛小组赛BO2 第二场 8.17
2018/08/18 DOTA
python基于mysql实现的简单队列以及跨进程锁实例详解
2014/07/07 Python
Python计算三角函数之asin()方法的使用
2015/05/15 Python
运用TensorFlow进行简单实现线性回归、梯度下降示例
2018/03/05 Python
Python 3.x 安装opencv+opencv_contrib的操作方法
2018/04/02 Python
Python 通过requests实现腾讯新闻抓取爬虫的方法
2019/02/22 Python
Python实现平行坐标图的两种方法小结
2019/07/04 Python
Python空间数据处理之GDAL读写遥感图像
2019/08/01 Python
Django中使用CORS实现跨域请求过程解析
2019/08/05 Python
关于tensorflow softmax函数用法解析
2020/06/30 Python
Python 如何展开嵌套的序列
2020/08/01 Python
python 自定义异常和主动抛出异常(raise)的操作
2020/12/11 Python
个人租房协议书样本
2014/10/01 职场文书
关于成立领导小组的通知
2015/04/23 职场文书
预备党员考察意见范文
2015/06/01 职场文书
《我和小伙伴》教学反思
2016/02/20 职场文书
解析Java异步之call future
2021/06/14 Java/Android
MongoDB 常用的crud操作语句
2021/06/20 MongoDB
如何搭建 MySQL 高可用高性能集群
2021/06/21 MySQL
Go语言基础map用法及示例详解
2021/11/17 Golang