tensorflow实现简单的卷积网络


Posted in Python onMay 24, 2018

使用tensorflow实现一个简单的卷积神经,使用的数据集是MNIST,本节将使用两个卷积层加一个全连接层,构建一个简单有代表性的卷积网络。

代码是按照书上的敲的,第一步就是导入数据库,设置节点的初始值,Tf.nn.conv2d是tensorflow中的2维卷积,参数x是输入,W是卷积的参数,比如【5,5,1,32】,前面两个数字代表卷积核的尺寸,第三个数字代表有几个通道,比如灰度图是1,彩色图是3.最后一个代表卷积的数量,总的实现代码如下:

from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNSIT_data/", one_hot=True) 
sess = tf.InteractiveSession() 
 
 
# In[2]: 
#由于W和b在各层中均要用到,先定义乘函数。 
#tf.truncated_normal:截断正态分布,即限制范围的正态分布 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
 
 
# In[7]: 
#bias初始化值0.1. 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
 
# In[12]: 
#tf.nn.conv2d:二维的卷积 
#conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None,data_format=None, name=None) 
#filter:A 4-D tensor of shape 
#   `[filter_height, filter_width, in_channels, out_channels]` 
#strides:步长,都是1表示所有点都不会被遗漏。1-D 4值,表示每歌dim的移动步长。 
# padding:边界的处理方式,“SAME"、"VALID”可选 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
#tf.nn.max_pool:最大值池化函数,即求2*2区域的最大值,保留最显著的特征。 
#max_pool(value, ksize, strides, padding, data_format="NHWC", name=None) 
#ksize:池化窗口的尺寸 
#strides:[1,2,2,1]表示横竖方向步长为2 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides = [1, 2, 2, 1], padding='SAME') 
 
 
x = tf.placeholder(tf.float32, [None, 784]) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
#tf.reshape:tensor的变形函数。 
#-1:样本数量不固定 
#28,28:新形状的shape 
#1:颜色通道数 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
 
 
#卷积层包含三部分:卷积计算、激活、池化 
#[5,5,1,32]表示卷积核的尺寸为5×5, 颜色通道为1, 有32个卷积核 
W_conv1 = weight_variable([5, 5, 1, 32]) 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
 
W_conv2 = weight_variable([5, 5, 32, 64]) 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
 
#经过2次2×2的池化后,图像的尺寸变为7×7,第二个卷积层有64个卷积核,生成64类特征,因此,卷积最后输出为7×7×64. 
#tensor进入全连接层之前,先将64张二维图像变形为1维图像,便于计算。 
W_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
 
#对全连接层做dropot 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob) 
 
 
#又一个全连接后foftmax分类 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_dropout, W_fc2) + b_fc2) 
 
 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv), reduction_indices=[1])) 
#AdamOptimizer:Adam优化函数 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
 
 
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
 
#训练,并且每100个batch计算一次精度 
tf.global_variables_initializer().run() 
for i in range(20000): 
  batch = mnist.train.next_batch(50) 
  if i%100 == 0: 
    train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], keep_prob:1.0}) 
    print("step %d, training accuracy %g" %(i, train_accuracy)) 
  train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 
 
 
#在测试集上测试 
print("test accuracy %g"%accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

 注意的是书上开始运行的代码是tf.global_variables_initializer().run(),但是在敲到代码中就会报错,也不知道为什么,可能是因为版本的问题吧,上网搜了一下,改为sess.run(tf.initialiaze_all_variables)即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 正则式 概述及常用字符
May 07 Python
Python的shutil模块中文件的复制操作函数详解
Jul 05 Python
Mac 上切换Python多版本
Jun 17 Python
Python Socket编程之多线程聊天室
Jul 28 Python
使用Python实现微信提醒备忘录功能
Dec 04 Python
Python实现的银行系统模拟程序完整案例
Apr 12 Python
python使用tomorrow实现多线程的例子
Jul 20 Python
python多线程与多进程及其区别详解
Aug 08 Python
详解python中docx库的安装过程
Nov 08 Python
flask框架url与重定向操作实例详解
Jan 25 Python
git查看、创建、删除、本地、远程分支方法详解
Feb 18 Python
python 实现倒计时功能(gui界面)
Nov 11 Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
Django rest framework实现分页的示例
May 24 #Python
You might like
关于时间计算的结总
2006/12/06 PHP
PHP MYSQL乱码问题,使用SET NAMES utf8校正
2009/11/30 PHP
PHP 时间转换Unix时间戳代码
2010/01/22 PHP
PHP中设置时区,记录日志文件的实现代码
2013/01/07 PHP
php使用APC实现实时上传进度条功能
2015/10/26 PHP
Laravel实现自定义错误输出内容的方法
2016/10/10 PHP
探究Laravel使用env函数读取环境变量为null的问题
2016/12/06 PHP
php 读写json文件及修改json的方法
2018/03/07 PHP
[原创]网络复制内容时常用的正则+editplus
2006/11/30 Javascript
30个让人兴奋的视差滚动(Parallax Scrolling)效果网站
2012/03/04 Javascript
JS完整获取IE浏览器信息包括类型、版本、语言等等
2014/05/22 Javascript
Nodejs极简入门教程(一):模块机制
2014/10/25 NodeJs
jquery+css3实现会动的小圆圈效果
2016/01/27 Javascript
使用jquery提交form表单并自定义action的方法
2016/05/25 Javascript
原生js实现无缝轮播图效果
2017/01/11 Javascript
JavaScript实现定时页面跳转功能示例
2017/02/14 Javascript
详解利用 Vue.js 实现前后端分离的RBAC角色权限管理
2017/09/15 Javascript
JavaScript伪数组用法实例分析
2017/12/22 Javascript
ES6下子组件调用父组件的方法(推荐)
2018/02/23 Javascript
Vue自定义指令上报Google Analytics事件统计的方法
2019/02/25 Javascript
基于js实现抽红包并分配代码实例
2019/09/19 Javascript
解决在Vue中使用axios用form表单出现的问题
2019/10/30 Javascript
详解supervisor使用教程
2017/11/21 Python
python通过伪装头部数据抵抗反爬虫的实例
2018/05/07 Python
用python实现一个简单的验证码
2020/12/09 Python
Python中Qslider控件实操详解
2021/02/20 Python
AmazeUI的JS表单验证框架实战示例分享
2020/08/21 HTML / CSS
Corelle官方网站:购买康宁餐具
2016/11/02 全球购物
澳大利亚最好的在线时尚精品店:Princess Polly
2018/01/03 全球购物
英国卫浴商店:Ergonomic Design
2019/09/22 全球购物
分厂厂长岗位职责
2013/12/29 职场文书
2014春晚主持词
2014/03/25 职场文书
自主招生自荐信怎么写
2015/03/24 职场文书
2015年高校辅导员工作总结
2015/04/20 职场文书
小学生禁毒教育心得体会
2016/01/15 职场文书
react 项目中引入图片的几种方式
2021/06/02 Javascript