tensorflow实现简单的卷积神经网络


Posted in Python onMay 24, 2018

本文实例为大家分享了Android九宫格图片展示的具体代码,供大家参考,具体内容如下

一.知识点总结

1.  卷积神经网络出现的初衷是降低对图像的预处理,避免建立复杂的特征工程。因为卷积神经网络在训练的过程中,自己会提取特征。

2.   灵感来自于猫的视觉皮层研究,每一个视觉神经元只会处理一小块区域的视觉图像,即感知野。放到卷积神经网络里就是每一个隐含节点只与设定范围内的像素点相连(设定范围就是卷积核的尺寸),而全连接层是每个像素点与每个隐含节点相连。这种感知野也称之为局部感知。

例如,一张1000*1000的图片,如果隐含层有100*100个隐含节点全连接,则需要1000*1000*100*100+100*100个参数,而如果有10*10的范围局部感知,用同样多的隐含节点,只需要10*10*100*100+100*100个参数。

3.  把卷积的过程称作卷积滤波,除了上面的局部感知,卷积滤波还有一个化简操作——权值共享。即一个卷积滤波中的所有隐含节点与感知图像连接的权值是一样的,这样,上述例子的参数减少为10*10+100*100个了。W的数量等于感知范围的尺寸。

4.  为了抗变形和减小复杂度,卷积层同时还要做激活和池化。激活函数前一章已经弄明白了,池化,相当于降采样,将n*n的像素区域采样为m*m区域,m通常小于n。通常选择最大池化,即选择区域内的最大像素点。 

5.  总结来讲,卷积有三个要点:局部连接、权值共享、池化降采样。一个卷积过程包含三个步骤:卷积滤波、激活、池化。 

6.  卷积滤波中的卷积范围可以用一个词来代替——卷积核,卷积核等同于卷积滤波中的一个隐含节点感知范围。由于权值共享,相当于一个卷积核对整个图像做多次小范围滤波,每滤一次波生成一个小的特征图像,多次滤波后将所有小特征图像组合起来,生成了对整个图像的feature map。通常,一个卷积滤波过程有多个卷积核卷积,生成多张feature map。

所有的feature map都会被池化,然后输入下一层。 

7.  需要训练的权值(参数)的数量只和卷积核尺寸有关,隐含节点(即卷积核要卷积的次数)只和卷积的卷积步长、图像尺寸有关。

个人理解,一个卷积核对整个图像卷积的过程,就像是一个棋子,在整个棋盘上按照步长跳动,每跳动一次,对感知范围内的像素点做一次连接计算。 

8.  CNN在结构上和图像的结构更为接近,都是2D的,因此,早期用在图像上效果很好,但是最近,CNN用于NLP也很热门。

二.程序解析

# coding: utf-8 
 
# In[1]: 
 
from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNSIT_data/", one_hot=True) 
sess = tf.InteractiveSession() 
 
 
# In[2]: 
#由于W和b在各层中均要用到,先定义乘函数。 
#tf.truncated_normal:截断正态分布,即限制范围的正态分布 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
 
 
# In[7]: 
#bias初始化值0.1. 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
 
# In[12]: 
#tf.nn.conv2d:二维的卷积 
#conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None,data_format=None, name=None) 
#filter:A 4-D tensor of shape 
#   `[filter_height, filter_width, in_channels, out_channels]` 
#strides:步长,都是1表示所有点都不会被遗漏。1-D 4值,表示每歌dim的移动步长。 
# padding:边界的处理方式,“SAME"、"VALID”可选 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
#tf.nn.max_pool:最大值池化函数,即求2*2区域的最大值,保留最显著的特征。 
#max_pool(value, ksize, strides, padding, data_format="NHWC", name=None) 
#ksize:池化窗口的尺寸 
#strides:[1,2,2,1]表示横竖方向步长为2 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides = [1, 2, 2, 1], padding='SAME') 
 
 
x = tf.placeholder(tf.float32, [None, 784]) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
#tf.reshape:tensor的变形函数。 
#-1:样本数量不固定 
#28,28:新形状的shape 
#1:颜色通道数 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
 
 
#卷积层包含三部分:卷积计算、激活、池化 
#[5,5,1,32]表示卷积核的尺寸为5×5, 颜色通道为1, 有32个卷积核 
W_conv1 = weight_variable([5, 5, 1, 32]) 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
 
W_conv2 = weight_variable([5, 5, 32, 64]) 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
 
#经过2次2×2的池化后,图像的尺寸变为7×7,第二个卷积层有64个卷积核,生成64类特征,因此,卷积最后输出为7×7×64. 
#tensor进入全连接层之前,先将64张二维图像变形为1维图像,便于计算。 
W_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
 
#对全连接层做dropot 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob) 
 
 
#又一个全连接后foftmax分类 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_dropout, W_fc2) + b_fc2) 
 
 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv), reduction_indices=[1])) 
#AdamOptimizer:Adam优化函数 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
 
 
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
 
#训练,并且每100个batch计算一次精度 
tf.global_variables_initializer().run() 
for i in range(20000): 
  batch = mnist.train.next_batch(50) 
  if i%100 == 0: 
    train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], keep_prob:1.0}) 
    print("step %d, training accuracy %g" %(i, train_accuracy)) 
  train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 
 
 
#在测试集上测试 
print("test accuracy %g"%accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

补充一下目前三个网络在mnist上的精度分别为:

无隐含层的softmax:91.5%

加入一个全连接隐含层的感知机:98.1%

此cnn:99.07%

和作者的训练结果有细微的差异,可能设备不同吧。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3实现抓取javascript动态生成的html网页功能示例
Aug 22 Python
python时间日期函数与利用pandas进行时间序列处理详解
Mar 13 Python
详解Python 协程的详细用法使用和例子
Jun 15 Python
Python集中化管理平台Ansible介绍与YAML简介
Jun 12 Python
python之生产者消费者模型实现详解
Jul 27 Python
python对csv文件追加写入列的方法
Aug 01 Python
基于TensorBoard中graph模块图结构分析
Feb 15 Python
Python中and和or如何使用
May 28 Python
如何基于Python代码实现高精度免费OCR工具
Jun 18 Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 Python
python删除csv文件的行列
Apr 06 Python
总结Pyinstaller打包的高级用法
Jun 28 Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
You might like
用Apache反向代理设置对外的WWW和文件服务器
2006/10/09 PHP
PHP file_get_contents 函数超时的几种解决方法
2009/07/30 PHP
理解PHP中的stdClass类
2014/04/18 PHP
php中的观察者模式简单实例
2015/01/20 PHP
php链表用法实例分析
2015/07/09 PHP
Windows Server 2008 R2和2012中PHP连接MySQL过慢的解决方法
2016/07/02 PHP
yii2.0整合阿里云oss删除单个文件的方法
2017/09/19 PHP
php如何计算两坐标点之间的距离
2018/12/29 PHP
jQuery 加上最后自己的验证
2009/11/04 Javascript
Dojo 学习笔记入门篇 First Dojo Example
2009/11/15 Javascript
IE6下JS动态设置图片src地址问题
2010/01/08 Javascript
文本框输入时 实现自动提示(像百度、google一样)
2012/04/05 Javascript
jquery 操作DOM的基本用法分享
2012/04/05 Javascript
限制上传文件大小和格式的jQuery插件实例
2015/01/24 Javascript
javascript实现table表格隔行变色的方法
2015/05/13 Javascript
jQuery实现网站添加高亮突出显示效果的方法
2015/06/26 Javascript
JS打开摄像头并截图上传示例
2017/02/18 Javascript
Vue.js组件通信的几种姿势
2017/10/23 Javascript
VueJs组件之父子通讯的方式
2018/05/06 Javascript
React Router V4使用指南(精讲)
2018/09/17 Javascript
一个因@click.stop引发的bug的解决
2019/01/08 Javascript
Python天气预报采集器实现代码(网页爬虫)
2012/10/07 Python
解决python写的windows服务不能启动的问题
2014/04/15 Python
python实现井字棋游戏
2020/03/30 Python
python能否java成为主流语言吗
2020/06/22 Python
python对批量WAV音频进行等长分割的方法实现
2020/09/25 Python
python3中celery异步框架简单使用+守护进程方式启动
2021/01/20 Python
台湾百利市购物中心:e-Payless
2017/08/16 全球购物
开发中都用到了那些设计模式?用在什么场合?
2014/08/21 面试题
办公室文员自荐书
2014/02/03 职场文书
市场拓展计划书
2014/05/03 职场文书
保护地球的标语
2014/06/17 职场文书
旅游与酒店管理专业求职信
2014/07/21 职场文书
在Django中使用MQTT的方法
2021/05/10 Python
mysql 如何获取两个集合的交集/差集/并集
2021/06/08 MySQL
WinServer2012搭建DNS服务器的方法步骤
2022/06/10 Servers