tensorflow实现简单的卷积网络


Posted in Python onMay 24, 2018

使用tensorflow实现一个简单的卷积神经,使用的数据集是MNIST,本节将使用两个卷积层加一个全连接层,构建一个简单有代表性的卷积网络。

代码是按照书上的敲的,第一步就是导入数据库,设置节点的初始值,Tf.nn.conv2d是tensorflow中的2维卷积,参数x是输入,W是卷积的参数,比如【5,5,1,32】,前面两个数字代表卷积核的尺寸,第三个数字代表有几个通道,比如灰度图是1,彩色图是3.最后一个代表卷积的数量,总的实现代码如下:

from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNSIT_data/", one_hot=True) 
sess = tf.InteractiveSession() 
 
 
# In[2]: 
#由于W和b在各层中均要用到,先定义乘函数。 
#tf.truncated_normal:截断正态分布,即限制范围的正态分布 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
 
 
# In[7]: 
#bias初始化值0.1. 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
 
# In[12]: 
#tf.nn.conv2d:二维的卷积 
#conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None,data_format=None, name=None) 
#filter:A 4-D tensor of shape 
#   `[filter_height, filter_width, in_channels, out_channels]` 
#strides:步长,都是1表示所有点都不会被遗漏。1-D 4值,表示每歌dim的移动步长。 
# padding:边界的处理方式,“SAME"、"VALID”可选 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
#tf.nn.max_pool:最大值池化函数,即求2*2区域的最大值,保留最显著的特征。 
#max_pool(value, ksize, strides, padding, data_format="NHWC", name=None) 
#ksize:池化窗口的尺寸 
#strides:[1,2,2,1]表示横竖方向步长为2 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides = [1, 2, 2, 1], padding='SAME') 
 
 
x = tf.placeholder(tf.float32, [None, 784]) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
#tf.reshape:tensor的变形函数。 
#-1:样本数量不固定 
#28,28:新形状的shape 
#1:颜色通道数 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
 
 
#卷积层包含三部分:卷积计算、激活、池化 
#[5,5,1,32]表示卷积核的尺寸为5×5, 颜色通道为1, 有32个卷积核 
W_conv1 = weight_variable([5, 5, 1, 32]) 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
 
W_conv2 = weight_variable([5, 5, 32, 64]) 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
 
#经过2次2×2的池化后,图像的尺寸变为7×7,第二个卷积层有64个卷积核,生成64类特征,因此,卷积最后输出为7×7×64. 
#tensor进入全连接层之前,先将64张二维图像变形为1维图像,便于计算。 
W_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
 
#对全连接层做dropot 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob) 
 
 
#又一个全连接后foftmax分类 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_dropout, W_fc2) + b_fc2) 
 
 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv), reduction_indices=[1])) 
#AdamOptimizer:Adam优化函数 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
 
 
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
 
#训练,并且每100个batch计算一次精度 
tf.global_variables_initializer().run() 
for i in range(20000): 
  batch = mnist.train.next_batch(50) 
  if i%100 == 0: 
    train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], keep_prob:1.0}) 
    print("step %d, training accuracy %g" %(i, train_accuracy)) 
  train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 
 
 
#在测试集上测试 
print("test accuracy %g"%accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

 注意的是书上开始运行的代码是tf.global_variables_initializer().run(),但是在敲到代码中就会报错,也不知道为什么,可能是因为版本的问题吧,上网搜了一下,改为sess.run(tf.initialiaze_all_variables)即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现对PPT文件进行截图操作的方法
Apr 28 Python
python访问系统环境变量的方法
Apr 29 Python
Python随机生成信用卡卡号的实现方法
May 14 Python
Python 处理数据的实例详解
Aug 10 Python
python自动化脚本安装指定版本python环境详解
Sep 14 Python
pandas 转换成行列表进行读取与Nan处理的方法
Oct 30 Python
python协程之动态添加任务的方法
Feb 19 Python
python安装numpy和pandas的方法步骤
May 27 Python
Python Django 简单分页的实现代码解析
Aug 21 Python
Python包和模块的分发详细介绍
Jun 19 Python
python 模拟登录B站的示例代码
Dec 15 Python
详解win10下pytorch-gpu安装以及CUDA详细安装过程
Jan 28 Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
Django rest framework实现分页的示例
May 24 #Python
You might like
php 删除无限级目录与文件代码共享
2008/11/22 PHP
解决ajax+php中文乱码的方法详解
2013/06/09 PHP
浅析HTTP消息头网页缓存控制以及header常用指令介绍
2013/06/28 PHP
php实现将数组转换为XML的方法
2015/03/09 PHP
php-fpm超时时间设置request_terminate_timeout资源问题分析
2019/09/27 PHP
PHP sdk文档处理常用代码示例解析
2020/12/09 PHP
JavaScript 动态生成方法的例子
2009/07/22 Javascript
javascript 设为首页与加入收藏兼容多浏览器代码
2011/01/11 Javascript
jQuery getJSON()+.ashx 实现分页(改进版)
2013/03/28 Javascript
JS格式化数字金额用逗号隔开保留两位小数
2013/10/18 Javascript
JavaScript异步编程Promise模式的6个特性
2014/04/03 Javascript
JS获取当前日期时间并定时刷新示例
2021/03/04 Javascript
JavaScript中的分号插入机制详细介绍
2015/02/11 Javascript
Angular 理解module和injector,即依赖注入
2016/09/07 Javascript
JavaScript实现倒计时跳转页面功能【实用】
2016/12/13 Javascript
Vue生命周期示例详解
2017/04/12 Javascript
Python正则表达式匹配HTML页面编码
2015/04/08 Python
Python实现扣除个人税后的工资计算器示例
2018/03/26 Python
Python实现中英文全文搜索的示例
2020/12/04 Python
Canvas与Image互相转换示例代码
2013/08/09 HTML / CSS
html5使用canvas画三角形
2014/12/15 HTML / CSS
拉斯维加斯城市观光通行证:Las Vegas Pass
2019/05/21 全球购物
什么是TCP/IP
2014/07/27 面试题
What's the difference between an interface and abstract class? (接口与抽象类有什么区别)
2012/10/29 面试题
介绍一下Ruby中的对象,属性和方法
2012/07/11 面试题
初婚未育证明
2014/01/15 职场文书
给儿子的表扬信
2014/01/15 职场文书
浙大毕业生自荐信
2014/01/26 职场文书
优秀民警事迹材料
2014/01/29 职场文书
电子商务专业毕业生求职信
2014/06/12 职场文书
2015驻村干部工作总结
2015/04/07 职场文书
卫生院义诊活动总结
2015/05/07 职场文书
2015年学校总务处工作总结
2015/05/19 职场文书
校运会广播稿
2015/08/19 职场文书
apache基于端口创建虚拟主机的示例
2021/04/24 Servers
【海涛教你打dota】体验一超神发条:咱是抢盾专业户
2022/04/01 DOTA