tensorflow实现简单的卷积网络


Posted in Python onMay 24, 2018

使用tensorflow实现一个简单的卷积神经,使用的数据集是MNIST,本节将使用两个卷积层加一个全连接层,构建一个简单有代表性的卷积网络。

代码是按照书上的敲的,第一步就是导入数据库,设置节点的初始值,Tf.nn.conv2d是tensorflow中的2维卷积,参数x是输入,W是卷积的参数,比如【5,5,1,32】,前面两个数字代表卷积核的尺寸,第三个数字代表有几个通道,比如灰度图是1,彩色图是3.最后一个代表卷积的数量,总的实现代码如下:

from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNSIT_data/", one_hot=True) 
sess = tf.InteractiveSession() 
 
 
# In[2]: 
#由于W和b在各层中均要用到,先定义乘函数。 
#tf.truncated_normal:截断正态分布,即限制范围的正态分布 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
 
 
# In[7]: 
#bias初始化值0.1. 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
 
# In[12]: 
#tf.nn.conv2d:二维的卷积 
#conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None,data_format=None, name=None) 
#filter:A 4-D tensor of shape 
#   `[filter_height, filter_width, in_channels, out_channels]` 
#strides:步长,都是1表示所有点都不会被遗漏。1-D 4值,表示每歌dim的移动步长。 
# padding:边界的处理方式,“SAME"、"VALID”可选 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
#tf.nn.max_pool:最大值池化函数,即求2*2区域的最大值,保留最显著的特征。 
#max_pool(value, ksize, strides, padding, data_format="NHWC", name=None) 
#ksize:池化窗口的尺寸 
#strides:[1,2,2,1]表示横竖方向步长为2 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides = [1, 2, 2, 1], padding='SAME') 
 
 
x = tf.placeholder(tf.float32, [None, 784]) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
#tf.reshape:tensor的变形函数。 
#-1:样本数量不固定 
#28,28:新形状的shape 
#1:颜色通道数 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
 
 
#卷积层包含三部分:卷积计算、激活、池化 
#[5,5,1,32]表示卷积核的尺寸为5×5, 颜色通道为1, 有32个卷积核 
W_conv1 = weight_variable([5, 5, 1, 32]) 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
 
W_conv2 = weight_variable([5, 5, 32, 64]) 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
 
#经过2次2×2的池化后,图像的尺寸变为7×7,第二个卷积层有64个卷积核,生成64类特征,因此,卷积最后输出为7×7×64. 
#tensor进入全连接层之前,先将64张二维图像变形为1维图像,便于计算。 
W_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
 
#对全连接层做dropot 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob) 
 
 
#又一个全连接后foftmax分类 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_dropout, W_fc2) + b_fc2) 
 
 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv), reduction_indices=[1])) 
#AdamOptimizer:Adam优化函数 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
 
 
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
 
#训练,并且每100个batch计算一次精度 
tf.global_variables_initializer().run() 
for i in range(20000): 
  batch = mnist.train.next_batch(50) 
  if i%100 == 0: 
    train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], keep_prob:1.0}) 
    print("step %d, training accuracy %g" %(i, train_accuracy)) 
  train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 
 
 
#在测试集上测试 
print("test accuracy %g"%accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

 注意的是书上开始运行的代码是tf.global_variables_initializer().run(),但是在敲到代码中就会报错,也不知道为什么,可能是因为版本的问题吧,上网搜了一下,改为sess.run(tf.initialiaze_all_variables)即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用strip()方法删除字符串中空格的教程
May 20 Python
python操作oracle的完整教程分享
Jan 30 Python
详解Python循环作用域与闭包
Mar 21 Python
python django下载大的csv文件实现方法分析
Jul 19 Python
Django上线部署之IIS的配置方法
Aug 22 Python
python文字转语音的实例代码分析
Nov 12 Python
python实现指定ip端口扫描方式
Dec 17 Python
python读取多层嵌套文件夹中的文件实例
Feb 27 Python
Django项目uwsgi+Nginx保姆级部署教程实现
Apr 19 Python
Python使用xlrd实现读取合并单元格
Jul 09 Python
Python代码覆盖率统计工具coverage.py用法详解
Nov 25 Python
python中reload重载实例用法
Dec 15 Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
Django rest framework实现分页的示例
May 24 #Python
You might like
php&java(一)
2006/10/09 PHP
php中检查文件或目录是否存在的代码小结
2012/10/22 PHP
php中将数组转成字符串并保存到数据库中的函数代码
2013/09/29 PHP
Thinkphp连表查询及数据导出方法示例
2016/10/15 PHP
一个网马的tips实现分析
2010/11/28 Javascript
IE事件对象(The Internet Explorer Event Object)
2012/06/27 Javascript
js 获取radio按钮值的实例
2013/08/17 Javascript
javascript实现验证身份证号的有效性并提示
2015/04/30 Javascript
图解Sublime Text3使用技巧
2015/12/21 Javascript
详解Document.Cookie
2015/12/25 Javascript
分享两段简单的JS代码防止SQL注入
2016/04/12 Javascript
JS组件Bootstrap实现下拉菜单效果代码
2016/04/26 Javascript
js根据手机客户端浏览器类型,判断跳转官网/手机网站多个实例代码
2016/04/30 Javascript
省市选择的简单实现(基于zepto.js)
2016/06/21 Javascript
JavaScript中String对象的方法介绍
2017/01/04 Javascript
javascript数据结构之串的概念与用法分析
2017/04/12 Javascript
React-Native使用Mobx实现购物车功能
2017/09/14 Javascript
Angular父组件调用子组件的方法
2018/04/02 Javascript
vue.js实现左边导航切换右边内容
2019/10/21 Javascript
JavaScript设计模型Iterator实例解析
2020/01/22 Javascript
Vue 修改网站图标的方法
2020/12/31 Vue.js
零基础写python爬虫之爬虫编写全记录
2014/11/06 Python
python获取各操作系统硬件信息的方法
2015/06/03 Python
python 获取url中的参数列表实例
2018/12/18 Python
Python求解正态分布置信区间教程
2019/11/20 Python
PyCharm 无法 import pandas 程序卡住的解决方式
2020/03/09 Python
Python如何使用PIL Image制作GIF图片
2020/05/16 Python
在css3中background-clip属性与background-origin属性的用法介绍
2012/11/13 HTML / CSS
eBay意大利购物网站:eBay.it
2019/09/04 全球购物
解释一下ruby中的特殊方法与特殊类
2013/02/26 面试题
银行简历自我评价
2014/02/11 职场文书
班风学风建设方案
2014/05/06 职场文书
住房抵押登记委托书
2014/09/27 职场文书
作文评语集锦
2014/12/25 职场文书
Python编程根据字典列表相同键的值进行合并
2021/10/05 Python
python通过新建环境安装tfx的问题
2022/05/20 Python