tensorflow实现简单的卷积神经网络


Posted in Python onMay 24, 2018

本文实例为大家分享了Android九宫格图片展示的具体代码,供大家参考,具体内容如下

一.知识点总结

1.  卷积神经网络出现的初衷是降低对图像的预处理,避免建立复杂的特征工程。因为卷积神经网络在训练的过程中,自己会提取特征。

2.   灵感来自于猫的视觉皮层研究,每一个视觉神经元只会处理一小块区域的视觉图像,即感知野。放到卷积神经网络里就是每一个隐含节点只与设定范围内的像素点相连(设定范围就是卷积核的尺寸),而全连接层是每个像素点与每个隐含节点相连。这种感知野也称之为局部感知。

例如,一张1000*1000的图片,如果隐含层有100*100个隐含节点全连接,则需要1000*1000*100*100+100*100个参数,而如果有10*10的范围局部感知,用同样多的隐含节点,只需要10*10*100*100+100*100个参数。

3.  把卷积的过程称作卷积滤波,除了上面的局部感知,卷积滤波还有一个化简操作——权值共享。即一个卷积滤波中的所有隐含节点与感知图像连接的权值是一样的,这样,上述例子的参数减少为10*10+100*100个了。W的数量等于感知范围的尺寸。

4.  为了抗变形和减小复杂度,卷积层同时还要做激活和池化。激活函数前一章已经弄明白了,池化,相当于降采样,将n*n的像素区域采样为m*m区域,m通常小于n。通常选择最大池化,即选择区域内的最大像素点。 

5.  总结来讲,卷积有三个要点:局部连接、权值共享、池化降采样。一个卷积过程包含三个步骤:卷积滤波、激活、池化。 

6.  卷积滤波中的卷积范围可以用一个词来代替——卷积核,卷积核等同于卷积滤波中的一个隐含节点感知范围。由于权值共享,相当于一个卷积核对整个图像做多次小范围滤波,每滤一次波生成一个小的特征图像,多次滤波后将所有小特征图像组合起来,生成了对整个图像的feature map。通常,一个卷积滤波过程有多个卷积核卷积,生成多张feature map。

所有的feature map都会被池化,然后输入下一层。 

7.  需要训练的权值(参数)的数量只和卷积核尺寸有关,隐含节点(即卷积核要卷积的次数)只和卷积的卷积步长、图像尺寸有关。

个人理解,一个卷积核对整个图像卷积的过程,就像是一个棋子,在整个棋盘上按照步长跳动,每跳动一次,对感知范围内的像素点做一次连接计算。 

8.  CNN在结构上和图像的结构更为接近,都是2D的,因此,早期用在图像上效果很好,但是最近,CNN用于NLP也很热门。

二.程序解析

# coding: utf-8 
 
# In[1]: 
 
from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNSIT_data/", one_hot=True) 
sess = tf.InteractiveSession() 
 
 
# In[2]: 
#由于W和b在各层中均要用到,先定义乘函数。 
#tf.truncated_normal:截断正态分布,即限制范围的正态分布 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
 
 
# In[7]: 
#bias初始化值0.1. 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
 
# In[12]: 
#tf.nn.conv2d:二维的卷积 
#conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None,data_format=None, name=None) 
#filter:A 4-D tensor of shape 
#   `[filter_height, filter_width, in_channels, out_channels]` 
#strides:步长,都是1表示所有点都不会被遗漏。1-D 4值,表示每歌dim的移动步长。 
# padding:边界的处理方式,“SAME"、"VALID”可选 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
#tf.nn.max_pool:最大值池化函数,即求2*2区域的最大值,保留最显著的特征。 
#max_pool(value, ksize, strides, padding, data_format="NHWC", name=None) 
#ksize:池化窗口的尺寸 
#strides:[1,2,2,1]表示横竖方向步长为2 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides = [1, 2, 2, 1], padding='SAME') 
 
 
x = tf.placeholder(tf.float32, [None, 784]) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
#tf.reshape:tensor的变形函数。 
#-1:样本数量不固定 
#28,28:新形状的shape 
#1:颜色通道数 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
 
 
#卷积层包含三部分:卷积计算、激活、池化 
#[5,5,1,32]表示卷积核的尺寸为5×5, 颜色通道为1, 有32个卷积核 
W_conv1 = weight_variable([5, 5, 1, 32]) 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
 
W_conv2 = weight_variable([5, 5, 32, 64]) 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
 
#经过2次2×2的池化后,图像的尺寸变为7×7,第二个卷积层有64个卷积核,生成64类特征,因此,卷积最后输出为7×7×64. 
#tensor进入全连接层之前,先将64张二维图像变形为1维图像,便于计算。 
W_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
 
#对全连接层做dropot 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob) 
 
 
#又一个全连接后foftmax分类 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_dropout, W_fc2) + b_fc2) 
 
 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv), reduction_indices=[1])) 
#AdamOptimizer:Adam优化函数 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
 
 
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
 
#训练,并且每100个batch计算一次精度 
tf.global_variables_initializer().run() 
for i in range(20000): 
  batch = mnist.train.next_batch(50) 
  if i%100 == 0: 
    train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], keep_prob:1.0}) 
    print("step %d, training accuracy %g" %(i, train_accuracy)) 
  train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 
 
 
#在测试集上测试 
print("test accuracy %g"%accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

补充一下目前三个网络在mnist上的精度分别为:

无隐含层的softmax:91.5%

加入一个全连接隐含层的感知机:98.1%

此cnn:99.07%

和作者的训练结果有细微的差异,可能设备不同吧。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3基础之基本数据类型概述
Aug 13 Python
python自然语言编码转换模块codecs介绍
Apr 08 Python
Python全局变量操作详解
Apr 14 Python
django使用图片延时加载引起后台404错误
Apr 18 Python
python中hashlib模块用法示例
Oct 30 Python
Python3基础教程之递归函数简单示例
Jun 07 Python
pytz格式化北京时间多出6分钟问题的解决方法
Jun 21 Python
python 设置xlabel,ylabel 坐标轴字体大小,字体类型
Jul 23 Python
Python generator生成器和yield表达式详解
Aug 08 Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 Python
Pytorch 实现数据集自定义读取
Jan 18 Python
python神经网络学习 使用Keras进行回归运算
May 04 Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
You might like
php zip文件解压类代码
2009/12/02 PHP
PHP IE中下载附件问题解决方法
2014/01/07 PHP
php简单实现查询数据库返回json数据
2015/04/16 PHP
Zend Framework框架之Zend_Mail实现发送Email邮件验证功能及解决标题乱码的方法
2016/03/21 PHP
php 函数中静态变量使用的问题实例分析
2020/03/05 PHP
Javascript计算时间差的函数分享
2011/07/04 Javascript
JQuery入门—编写一个简单的JQuery应用案例
2013/01/03 Javascript
13 款最热门的 jQuery 图像 360 度旋转插件推荐
2014/12/09 Javascript
使用canvas实现仿新浪微博头像截取上传功能
2015/09/02 Javascript
深入理解AngularJs-scope的脏检查(一)
2017/06/19 Javascript
angular写一个列表的选择全选交互组件的示例
2018/01/22 Javascript
jQuery实现简单复制json对象和json对象集合操作示例
2018/07/09 jQuery
微信小程序中遇到的iOS兼容性问题小结
2018/11/14 Javascript
微信小程序实现锚点跳转
2020/11/23 Javascript
python with statement 进行文件操作指南
2014/08/22 Python
Python字符串特性及常用字符串方法的简单笔记
2016/01/04 Python
pytz格式化北京时间多出6分钟问题的解决方法
2019/06/21 Python
Python requests获取网页常用方法解析
2020/02/20 Python
python实现简单井字棋游戏
2020/03/04 Python
python中关于数据类型的学习笔记
2020/07/19 Python
移动端Web页面的CSS3 flex布局快速上手指南
2016/05/31 HTML / CSS
Clarisonic美国官网:科莱丽声波洁面仪
2017/10/12 全球购物
澳大利亚香水在线:Price Rite Mart
2017/12/28 全球购物
旧时光糖果:Old Time Candy
2018/02/05 全球购物
英国建筑用品在线:Building Supplies Online(BSO)
2018/04/30 全球购物
美国保健品专家:Life Extension
2018/05/04 全球购物
介绍一下Java中标识符的命名规则
2014/02/03 面试题
翻译专业应届生求职信
2013/11/23 职场文书
资产经营总监岗位职责范文
2013/12/01 职场文书
研究生毕业自我鉴定范文
2014/03/27 职场文书
《莫泊桑拜师》教学反思
2014/04/23 职场文书
大学生精神文明先进个人事迹材料
2014/05/02 职场文书
电子专业毕业生自荐信
2014/05/25 职场文书
2015暑假假期总结
2015/07/13 职场文书
简历自我评价:教师师德表现自我评价
2019/04/24 职场文书
Zabbix6通过ODBC方式监控Oracle 19C的详细过程
2022/09/23 Servers