tensorflow实现简单的卷积神经网络


Posted in Python onMay 24, 2018

本文实例为大家分享了Android九宫格图片展示的具体代码,供大家参考,具体内容如下

一.知识点总结

1.  卷积神经网络出现的初衷是降低对图像的预处理,避免建立复杂的特征工程。因为卷积神经网络在训练的过程中,自己会提取特征。

2.   灵感来自于猫的视觉皮层研究,每一个视觉神经元只会处理一小块区域的视觉图像,即感知野。放到卷积神经网络里就是每一个隐含节点只与设定范围内的像素点相连(设定范围就是卷积核的尺寸),而全连接层是每个像素点与每个隐含节点相连。这种感知野也称之为局部感知。

例如,一张1000*1000的图片,如果隐含层有100*100个隐含节点全连接,则需要1000*1000*100*100+100*100个参数,而如果有10*10的范围局部感知,用同样多的隐含节点,只需要10*10*100*100+100*100个参数。

3.  把卷积的过程称作卷积滤波,除了上面的局部感知,卷积滤波还有一个化简操作——权值共享。即一个卷积滤波中的所有隐含节点与感知图像连接的权值是一样的,这样,上述例子的参数减少为10*10+100*100个了。W的数量等于感知范围的尺寸。

4.  为了抗变形和减小复杂度,卷积层同时还要做激活和池化。激活函数前一章已经弄明白了,池化,相当于降采样,将n*n的像素区域采样为m*m区域,m通常小于n。通常选择最大池化,即选择区域内的最大像素点。 

5.  总结来讲,卷积有三个要点:局部连接、权值共享、池化降采样。一个卷积过程包含三个步骤:卷积滤波、激活、池化。 

6.  卷积滤波中的卷积范围可以用一个词来代替——卷积核,卷积核等同于卷积滤波中的一个隐含节点感知范围。由于权值共享,相当于一个卷积核对整个图像做多次小范围滤波,每滤一次波生成一个小的特征图像,多次滤波后将所有小特征图像组合起来,生成了对整个图像的feature map。通常,一个卷积滤波过程有多个卷积核卷积,生成多张feature map。

所有的feature map都会被池化,然后输入下一层。 

7.  需要训练的权值(参数)的数量只和卷积核尺寸有关,隐含节点(即卷积核要卷积的次数)只和卷积的卷积步长、图像尺寸有关。

个人理解,一个卷积核对整个图像卷积的过程,就像是一个棋子,在整个棋盘上按照步长跳动,每跳动一次,对感知范围内的像素点做一次连接计算。 

8.  CNN在结构上和图像的结构更为接近,都是2D的,因此,早期用在图像上效果很好,但是最近,CNN用于NLP也很热门。

二.程序解析

# coding: utf-8 
 
# In[1]: 
 
from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNSIT_data/", one_hot=True) 
sess = tf.InteractiveSession() 
 
 
# In[2]: 
#由于W和b在各层中均要用到,先定义乘函数。 
#tf.truncated_normal:截断正态分布,即限制范围的正态分布 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
 
 
# In[7]: 
#bias初始化值0.1. 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
 
# In[12]: 
#tf.nn.conv2d:二维的卷积 
#conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None,data_format=None, name=None) 
#filter:A 4-D tensor of shape 
#   `[filter_height, filter_width, in_channels, out_channels]` 
#strides:步长,都是1表示所有点都不会被遗漏。1-D 4值,表示每歌dim的移动步长。 
# padding:边界的处理方式,“SAME"、"VALID”可选 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
#tf.nn.max_pool:最大值池化函数,即求2*2区域的最大值,保留最显著的特征。 
#max_pool(value, ksize, strides, padding, data_format="NHWC", name=None) 
#ksize:池化窗口的尺寸 
#strides:[1,2,2,1]表示横竖方向步长为2 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides = [1, 2, 2, 1], padding='SAME') 
 
 
x = tf.placeholder(tf.float32, [None, 784]) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
#tf.reshape:tensor的变形函数。 
#-1:样本数量不固定 
#28,28:新形状的shape 
#1:颜色通道数 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
 
 
#卷积层包含三部分:卷积计算、激活、池化 
#[5,5,1,32]表示卷积核的尺寸为5×5, 颜色通道为1, 有32个卷积核 
W_conv1 = weight_variable([5, 5, 1, 32]) 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
 
W_conv2 = weight_variable([5, 5, 32, 64]) 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
 
#经过2次2×2的池化后,图像的尺寸变为7×7,第二个卷积层有64个卷积核,生成64类特征,因此,卷积最后输出为7×7×64. 
#tensor进入全连接层之前,先将64张二维图像变形为1维图像,便于计算。 
W_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
 
#对全连接层做dropot 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob) 
 
 
#又一个全连接后foftmax分类 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_dropout, W_fc2) + b_fc2) 
 
 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv), reduction_indices=[1])) 
#AdamOptimizer:Adam优化函数 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
 
 
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
 
#训练,并且每100个batch计算一次精度 
tf.global_variables_initializer().run() 
for i in range(20000): 
  batch = mnist.train.next_batch(50) 
  if i%100 == 0: 
    train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], keep_prob:1.0}) 
    print("step %d, training accuracy %g" %(i, train_accuracy)) 
  train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 
 
 
#在测试集上测试 
print("test accuracy %g"%accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

补充一下目前三个网络在mnist上的精度分别为:

无隐含层的softmax:91.5%

加入一个全连接隐含层的感知机:98.1%

此cnn:99.07%

和作者的训练结果有细微的差异,可能设备不同吧。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python提示[Errno 32]Broken pipe导致线程crash错误解决方法
Nov 19 Python
Python删除空文件和空文件夹的方法
Jul 14 Python
如何高效使用Python字典的方法详解
Aug 31 Python
python如何把嵌套列表转变成普通列表
Mar 20 Python
谈谈Python中的while循环语句
Mar 10 Python
Python使用正则表达式分割字符串的实现方法
Jul 16 Python
利用python计算时间差(返回天数)
Sep 07 Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 Python
tensorflow 保存模型和取出中间权重例子
Jan 24 Python
Python基于pandas绘制散点图矩阵代码实例
Jun 04 Python
如何将json数据转换为python数据
Sep 04 Python
Pycharm安装python库的方法
Nov 24 Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
You might like
批量修改RAR文件注释的php代码
2010/11/20 PHP
PHP自动重命名文件实现方法
2014/11/04 PHP
node.js中的fs.truncateSync方法使用说明
2014/12/15 Javascript
javascript实现类似超链接的效果
2014/12/26 Javascript
JavaScript中的Math.atan2()方法使用详解
2015/06/15 Javascript
js+css实现上下翻页相册代码分享
2015/08/18 Javascript
学习JavaScript设计模式之策略模式
2016/01/12 Javascript
BootStrap创建响应式导航条实例代码
2016/05/31 Javascript
jquery 动态合并单元格的实现方法
2016/08/26 Javascript
js select实现省市区联动选择
2020/04/17 Javascript
微信小程序 动态绑定事件并实现事件修改样式
2017/04/13 Javascript
vue.draggable实现表格拖拽排序效果
2018/12/01 Javascript
JavaScript刷新页面的几种方法总结
2019/03/28 Javascript
ionic4+angular7+cordova上传图片功能的实例代码
2019/06/19 Javascript
Vue环境搭建+VSCode+Win10的详细教程
2020/08/19 Javascript
原生js实现点击按钮复制内容到剪切板
2020/11/19 Javascript
[47:03]完美世界DOTA2联赛PWL S3 Galaxy Racer vs Phoenix 第二场 12.10
2020/12/13 DOTA
linux系统使用python监控apache服务器进程脚本分享
2014/01/15 Python
Python科学计算环境推荐——Anaconda
2014/06/30 Python
Python 出现错误TypeError: ‘NoneType’ object is not iterable解决办法
2017/01/12 Python
python tensorflow基于cnn实现手写数字识别
2018/01/01 Python
Python爬虫实现(伪)球迷速成
2018/06/10 Python
Python3模拟登录操作实例分析
2019/03/12 Python
Python3匿名函数lambda介绍与使用示例
2019/05/18 Python
对Pytorch神经网络初始化kaiming分布详解
2019/08/18 Python
wxpython布局的实现方法
2019/11/01 Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
2020/01/21 Python
Python实现不规则图形填充的思路
2020/02/02 Python
Python使用configparser读取ini配置文件
2020/05/25 Python
python Cartopy的基础使用详解
2020/11/01 Python
Lombok插件安装(IDEA)及配置jar包使用详解
2020/11/04 Python
如何在Canvas上的图形/图像绑定事件监听的实现
2020/09/16 HTML / CSS
德国大型箱包和皮具商店:Koffer
2019/10/01 全球购物
2014年最新离婚协议书范本
2014/10/11 职场文书
《成长的天空》读后感3篇
2019/12/06 职场文书
测量JavaScript函数的性能各种方式对比
2021/04/27 Javascript