tensorflow实现简单的卷积网络


Posted in Python onMay 24, 2018

使用tensorflow实现一个简单的卷积神经,使用的数据集是MNIST,本节将使用两个卷积层加一个全连接层,构建一个简单有代表性的卷积网络。

代码是按照书上的敲的,第一步就是导入数据库,设置节点的初始值,Tf.nn.conv2d是tensorflow中的2维卷积,参数x是输入,W是卷积的参数,比如【5,5,1,32】,前面两个数字代表卷积核的尺寸,第三个数字代表有几个通道,比如灰度图是1,彩色图是3.最后一个代表卷积的数量,总的实现代码如下:

from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNSIT_data/", one_hot=True) 
sess = tf.InteractiveSession() 
 
 
# In[2]: 
#由于W和b在各层中均要用到,先定义乘函数。 
#tf.truncated_normal:截断正态分布,即限制范围的正态分布 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
 
 
# In[7]: 
#bias初始化值0.1. 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
 
# In[12]: 
#tf.nn.conv2d:二维的卷积 
#conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None,data_format=None, name=None) 
#filter:A 4-D tensor of shape 
#   `[filter_height, filter_width, in_channels, out_channels]` 
#strides:步长,都是1表示所有点都不会被遗漏。1-D 4值,表示每歌dim的移动步长。 
# padding:边界的处理方式,“SAME"、"VALID”可选 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
#tf.nn.max_pool:最大值池化函数,即求2*2区域的最大值,保留最显著的特征。 
#max_pool(value, ksize, strides, padding, data_format="NHWC", name=None) 
#ksize:池化窗口的尺寸 
#strides:[1,2,2,1]表示横竖方向步长为2 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides = [1, 2, 2, 1], padding='SAME') 
 
 
x = tf.placeholder(tf.float32, [None, 784]) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
#tf.reshape:tensor的变形函数。 
#-1:样本数量不固定 
#28,28:新形状的shape 
#1:颜色通道数 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
 
 
#卷积层包含三部分:卷积计算、激活、池化 
#[5,5,1,32]表示卷积核的尺寸为5×5, 颜色通道为1, 有32个卷积核 
W_conv1 = weight_variable([5, 5, 1, 32]) 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
 
W_conv2 = weight_variable([5, 5, 32, 64]) 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
 
#经过2次2×2的池化后,图像的尺寸变为7×7,第二个卷积层有64个卷积核,生成64类特征,因此,卷积最后输出为7×7×64. 
#tensor进入全连接层之前,先将64张二维图像变形为1维图像,便于计算。 
W_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
 
#对全连接层做dropot 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob) 
 
 
#又一个全连接后foftmax分类 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_dropout, W_fc2) + b_fc2) 
 
 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv), reduction_indices=[1])) 
#AdamOptimizer:Adam优化函数 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
 
 
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
 
#训练,并且每100个batch计算一次精度 
tf.global_variables_initializer().run() 
for i in range(20000): 
  batch = mnist.train.next_batch(50) 
  if i%100 == 0: 
    train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], keep_prob:1.0}) 
    print("step %d, training accuracy %g" %(i, train_accuracy)) 
  train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 
 
 
#在测试集上测试 
print("test accuracy %g"%accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

 注意的是书上开始运行的代码是tf.global_variables_initializer().run(),但是在敲到代码中就会报错,也不知道为什么,可能是因为版本的问题吧,上网搜了一下,改为sess.run(tf.initialiaze_all_variables)即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现线程池的方法
Jun 30 Python
python版简单工厂模式
Oct 16 Python
Python实现按特定格式对文件进行读写的方法示例
Nov 30 Python
tensorflow更改变量的值实例
Jul 30 Python
python多线程调用exit无法退出的解决方法
Feb 18 Python
Python Django实现layui风格+django分页功能的例子
Aug 29 Python
Python imread、newaxis用法详解
Nov 04 Python
python词云库wordCloud使用方法详解(解决中文乱码)
Feb 17 Python
解决python对齐错误的方法
Jul 16 Python
python 实现简单的计算器(gui界面)
Nov 11 Python
浅析Python模块之间的相互引用问题
Feb 26 Python
详解PyTorch模型保存与加载
Apr 28 Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
Django rest framework实现分页的示例
May 24 #Python
You might like
浅谈php处理后端&接口访问超时的解决方法
2016/10/29 PHP
jQuery 验证插件 Web前端设计模式(asp.net)
2010/10/17 Javascript
Jquery 表单验证类介绍与实例
2013/06/09 Javascript
基于javascript 闭包基础分享
2013/07/10 Javascript
jquery ajax的success回调函数中实现按钮置灰倒计时
2013/11/19 Javascript
在NodeJS中启用ECMAScript 6小结(windos以及Linux)
2014/07/15 NodeJs
jquery复选框多选赋值给文本框的方法
2015/01/27 Javascript
运行Node.js的IIS扩展iisnode安装配置笔记
2015/03/02 Javascript
js实现为a标签添加事件的方法(使用闭包循环)
2016/08/02 Javascript
ES6中module模块化开发实例浅析
2017/04/06 Javascript
JavaScript闭包_动力节点Java学院整理
2017/06/27 Javascript
基于node.js express mvc轻量级框架实践
2017/09/14 Javascript
vue2.0+ 从插件开发到npm发布的示例代码
2018/04/28 Javascript
VUE 自定义组件模板的方法详解
2019/08/30 Javascript
在VUE中实现文件下载并判断状态的方法
2019/11/08 Javascript
基于vue实现图片验证码倒计时60s功能
2019/12/10 Javascript
小程序中的箭头函数的具体使用
2020/06/19 Javascript
JS实现超级好看的鼠标小尾巴特效
2020/12/01 Javascript
[48:52]DOTA2上海特级锦标赛A组小组赛#2 Secret VS CDEC第一局
2016/02/25 DOTA
Python输出由1,2,3,4组成的互不相同且无重复的三位数
2018/02/01 Python
pandas.DataFrame选取/排除特定行的方法
2018/07/03 Python
Python线程下使用锁的技巧分享
2018/09/13 Python
总结Pyinstaller的坑及终极解决方法(小结)
2020/09/21 Python
利用html5 file api读取本地文件示例(如图片、PDF等)
2018/03/07 HTML / CSS
小学生家长评语大全
2014/02/10 职场文书
《大禹治水》教学反思
2014/04/27 职场文书
3分钟演讲稿
2014/04/30 职场文书
舞蹈教育学专业求职信
2014/06/29 职场文书
社区党员志愿服务活动方案
2014/08/18 职场文书
五星红旗迎风飘扬观后感
2015/06/17 职场文书
六一儿童节致辞
2015/07/31 职场文书
建房合同协议书
2016/03/21 职场文书
Java SSH 秘钥连接mysql数据库的方法
2021/06/28 Java/Android
oracle索引总结
2021/09/25 Oracle
【海涛dota解说】DCG联赛第一周 LGD VS DH
2022/04/01 DOTA
使用 MybatisPlus 连接 SqlServer 数据库解决 OFFSET 分页问题
2022/04/22 SQL Server