tensorflow实现简单的卷积网络


Posted in Python onMay 24, 2018

使用tensorflow实现一个简单的卷积神经,使用的数据集是MNIST,本节将使用两个卷积层加一个全连接层,构建一个简单有代表性的卷积网络。

代码是按照书上的敲的,第一步就是导入数据库,设置节点的初始值,Tf.nn.conv2d是tensorflow中的2维卷积,参数x是输入,W是卷积的参数,比如【5,5,1,32】,前面两个数字代表卷积核的尺寸,第三个数字代表有几个通道,比如灰度图是1,彩色图是3.最后一个代表卷积的数量,总的实现代码如下:

from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNSIT_data/", one_hot=True) 
sess = tf.InteractiveSession() 
 
 
# In[2]: 
#由于W和b在各层中均要用到,先定义乘函数。 
#tf.truncated_normal:截断正态分布,即限制范围的正态分布 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
 
 
# In[7]: 
#bias初始化值0.1. 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
 
# In[12]: 
#tf.nn.conv2d:二维的卷积 
#conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None,data_format=None, name=None) 
#filter:A 4-D tensor of shape 
#   `[filter_height, filter_width, in_channels, out_channels]` 
#strides:步长,都是1表示所有点都不会被遗漏。1-D 4值,表示每歌dim的移动步长。 
# padding:边界的处理方式,“SAME"、"VALID”可选 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
#tf.nn.max_pool:最大值池化函数,即求2*2区域的最大值,保留最显著的特征。 
#max_pool(value, ksize, strides, padding, data_format="NHWC", name=None) 
#ksize:池化窗口的尺寸 
#strides:[1,2,2,1]表示横竖方向步长为2 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides = [1, 2, 2, 1], padding='SAME') 
 
 
x = tf.placeholder(tf.float32, [None, 784]) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
#tf.reshape:tensor的变形函数。 
#-1:样本数量不固定 
#28,28:新形状的shape 
#1:颜色通道数 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
 
 
#卷积层包含三部分:卷积计算、激活、池化 
#[5,5,1,32]表示卷积核的尺寸为5×5, 颜色通道为1, 有32个卷积核 
W_conv1 = weight_variable([5, 5, 1, 32]) 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
 
W_conv2 = weight_variable([5, 5, 32, 64]) 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
 
#经过2次2×2的池化后,图像的尺寸变为7×7,第二个卷积层有64个卷积核,生成64类特征,因此,卷积最后输出为7×7×64. 
#tensor进入全连接层之前,先将64张二维图像变形为1维图像,便于计算。 
W_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
 
#对全连接层做dropot 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob) 
 
 
#又一个全连接后foftmax分类 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_dropout, W_fc2) + b_fc2) 
 
 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv), reduction_indices=[1])) 
#AdamOptimizer:Adam优化函数 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
 
 
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
 
#训练,并且每100个batch计算一次精度 
tf.global_variables_initializer().run() 
for i in range(20000): 
  batch = mnist.train.next_batch(50) 
  if i%100 == 0: 
    train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], keep_prob:1.0}) 
    print("step %d, training accuracy %g" %(i, train_accuracy)) 
  train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 
 
 
#在测试集上测试 
print("test accuracy %g"%accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

 注意的是书上开始运行的代码是tf.global_variables_initializer().run(),但是在敲到代码中就会报错,也不知道为什么,可能是因为版本的问题吧,上网搜了一下,改为sess.run(tf.initialiaze_all_variables)即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python struct模块解析
Jun 12 Python
Python脚本实现12306火车票查询系统
Sep 30 Python
python绘制中国大陆人口热力图
Nov 07 Python
python事件驱动event实现详解
Nov 21 Python
Linux下Python安装完成后使用pip命令的详细教程
Nov 22 Python
python之验证码生成(gvcode与captcha)
Jan 02 Python
python自动发微信监控报警
Sep 06 Python
Django之form组件自动校验数据实现
Jan 14 Python
使用python脚本自动生成K8S-YAML的方法示例
Jul 12 Python
如何用python免费看美剧
Aug 11 Python
浅析python 通⽤爬⾍和聚焦爬⾍
Sep 28 Python
pytorch 6 batch_train 批训练操作
May 28 Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
Django rest framework实现分页的示例
May 24 #Python
You might like
PHP中调用JAVA
2006/10/09 PHP
php中截取中文字符串的代码小结
2011/07/17 PHP
php中拷贝构造函数、赋值运算符重载
2012/07/25 PHP
thinkphp特殊标签用法概述
2014/11/24 PHP
利用Homestead快速运行一个Laravel项目的方法详解
2017/11/14 PHP
Laravel下生成验证码的类
2017/11/15 PHP
php实现统计IP数及在线人数的示例代码
2020/07/22 PHP
线路分流自动智能跳转代码,自动选择最快镜像网站(js)
2011/10/31 Javascript
调试Node.JS的辅助工具(NodeWatcher)
2012/01/04 Javascript
JavaScript中对象属性的添加和删除示例
2014/05/12 Javascript
jQuery学习笔记之 Ajax操作篇(二) - 数据传递
2014/06/23 Javascript
jquery实现树形菜单完整代码
2015/12/29 Javascript
jQuery height()、innerHeight()、outerHeight()函数的区别详解
2016/05/23 Javascript
vue如何引用其他组件(css和js)
2017/04/13 Javascript
微信小程序版翻牌小游戏
2018/01/26 Javascript
layer实现登录弹框,登录成功后关闭弹框并调用父窗口的例子
2019/09/11 Javascript
JS实现商城秒杀倒计时功能(动态设置秒杀时间)
2019/12/12 Javascript
[35:26]DOTA2上海特级锦标赛B组小组赛#2 VG VS Fnatic第三局
2016/02/26 DOTA
python使用PyFetion来发送短信的例子
2014/04/22 Python
Python Tkinter简单布局实例教程
2014/09/03 Python
Python yield 使用浅析
2015/05/28 Python
正确理解python中的关键字“with”与上下文管理器
2017/04/21 Python
python3读取MySQL-Front的MYSQL密码
2017/05/03 Python
Python批处理删除和重命名文件夹的实例
2018/07/11 Python
对django中render()与render_to_response()的区别详解
2018/10/16 Python
Python笔记之facade模式
2019/11/20 Python
python如何更新包
2020/06/11 Python
如何用python免费看美剧
2020/08/11 Python
如何使用 Python 读取文件和照片的创建日期
2020/09/05 Python
戴尔美国官网:Dell
2016/08/31 全球购物
台湾家适得:Homeget
2019/02/11 全球购物
shell变量的作用空间是什么
2013/08/17 面试题
人民教师的自我评价分享
2014/02/21 职场文书
高考寄语大全
2014/04/08 职场文书
党政领导班子群众路线对照检查材料思想汇报
2014/09/27 职场文书
生日宴会祝酒词
2015/08/10 职场文书