tensorflow实现简单的卷积网络


Posted in Python onMay 24, 2018

使用tensorflow实现一个简单的卷积神经,使用的数据集是MNIST,本节将使用两个卷积层加一个全连接层,构建一个简单有代表性的卷积网络。

代码是按照书上的敲的,第一步就是导入数据库,设置节点的初始值,Tf.nn.conv2d是tensorflow中的2维卷积,参数x是输入,W是卷积的参数,比如【5,5,1,32】,前面两个数字代表卷积核的尺寸,第三个数字代表有几个通道,比如灰度图是1,彩色图是3.最后一个代表卷积的数量,总的实现代码如下:

from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNSIT_data/", one_hot=True) 
sess = tf.InteractiveSession() 
 
 
# In[2]: 
#由于W和b在各层中均要用到,先定义乘函数。 
#tf.truncated_normal:截断正态分布,即限制范围的正态分布 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
 
 
# In[7]: 
#bias初始化值0.1. 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
 
# In[12]: 
#tf.nn.conv2d:二维的卷积 
#conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None,data_format=None, name=None) 
#filter:A 4-D tensor of shape 
#   `[filter_height, filter_width, in_channels, out_channels]` 
#strides:步长,都是1表示所有点都不会被遗漏。1-D 4值,表示每歌dim的移动步长。 
# padding:边界的处理方式,“SAME"、"VALID”可选 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
#tf.nn.max_pool:最大值池化函数,即求2*2区域的最大值,保留最显著的特征。 
#max_pool(value, ksize, strides, padding, data_format="NHWC", name=None) 
#ksize:池化窗口的尺寸 
#strides:[1,2,2,1]表示横竖方向步长为2 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides = [1, 2, 2, 1], padding='SAME') 
 
 
x = tf.placeholder(tf.float32, [None, 784]) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
#tf.reshape:tensor的变形函数。 
#-1:样本数量不固定 
#28,28:新形状的shape 
#1:颜色通道数 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
 
 
#卷积层包含三部分:卷积计算、激活、池化 
#[5,5,1,32]表示卷积核的尺寸为5×5, 颜色通道为1, 有32个卷积核 
W_conv1 = weight_variable([5, 5, 1, 32]) 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
 
W_conv2 = weight_variable([5, 5, 32, 64]) 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
 
#经过2次2×2的池化后,图像的尺寸变为7×7,第二个卷积层有64个卷积核,生成64类特征,因此,卷积最后输出为7×7×64. 
#tensor进入全连接层之前,先将64张二维图像变形为1维图像,便于计算。 
W_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
 
#对全连接层做dropot 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob) 
 
 
#又一个全连接后foftmax分类 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_dropout, W_fc2) + b_fc2) 
 
 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv), reduction_indices=[1])) 
#AdamOptimizer:Adam优化函数 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
 
 
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
 
#训练,并且每100个batch计算一次精度 
tf.global_variables_initializer().run() 
for i in range(20000): 
  batch = mnist.train.next_batch(50) 
  if i%100 == 0: 
    train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], keep_prob:1.0}) 
    print("step %d, training accuracy %g" %(i, train_accuracy)) 
  train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 
 
 
#在测试集上测试 
print("test accuracy %g"%accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

 注意的是书上开始运行的代码是tf.global_variables_initializer().run(),但是在敲到代码中就会报错,也不知道为什么,可能是因为版本的问题吧,上网搜了一下,改为sess.run(tf.initialiaze_all_variables)即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Flask框架Flask-Principal基本用法实例分析
Jul 23 Python
使用Python 正则匹配两个特定字符之间的字符方法
Dec 24 Python
Pandas读写CSV文件的方法示例
Mar 27 Python
python程序变成软件的实操方法
Jun 24 Python
Django的用户模块与权限系统的示例代码
Jul 24 Python
用Python徒手撸一个股票回测框架搭建【推荐】
Aug 05 Python
python3 selenium自动化 下拉框定位的例子
Aug 23 Python
python实现文字版扫雷
Apr 24 Python
python mysql自增字段AUTO_INCREMENT值的修改方式
May 18 Python
解析python 中/ 和 % 和 //(地板除)
Jun 28 Python
python的数学算法函数及公式用法
Nov 18 Python
用python对excel进行操作(读,写,修改)
Dec 25 Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
Django rest framework实现分页的示例
May 24 #Python
You might like
laravel实现Auth认证,登录、注册后的页面回跳方法
2019/09/30 PHP
基于PHP实现解密或加密Cloudflar邮箱保护
2020/06/24 PHP
自动更新作用
2006/10/08 Javascript
学习从实践开始之jQuery插件开发 菜单插件开发
2012/05/03 Javascript
jQuery $.get 的妙用 访问本地文本文件
2012/07/12 Javascript
图片动画横条广告带上下滚动的JS代码
2013/10/25 Javascript
JS实现让网页背景图片斜向移动的方法
2015/02/25 Javascript
jquery实现简单的二级导航下拉菜单效果
2015/09/07 Javascript
使用jquery插件qrcode生成二维码
2015/10/22 Javascript
java中String类型变量的赋值问题介绍
2016/03/23 Javascript
JavaScript实现设计模式中的单例模式的一些技巧总结
2016/05/17 Javascript
使用JS实现图片展示瀑布流效果(简单实例)
2016/09/06 Javascript
使用node.js搭建服务器
2017/05/20 Javascript
Bootstrap Table 删除和批量删除
2017/09/22 Javascript
微信小程序之onLaunch与onload异步问题详解
2019/03/28 Javascript
vue cli3 调用百度翻译API翻译页面的实现示例
2019/09/13 Javascript
javascript实现简易数码时钟
2020/03/30 Javascript
详解JavaScript中的Object.is()与"==="运算符总结
2020/06/17 Javascript
vue-cli3 引入 font-awesome的操作
2020/08/11 Javascript
[55:45]DOTA2上海特级锦标赛D组败者赛 Liquid VS COL第一局
2016/02/28 DOTA
编写Python脚本使得web页面上的代码高亮显示
2015/04/24 Python
Python 读取指定文件夹下的所有图像方法
2018/04/27 Python
Django配置MySQL数据库的完整步骤
2019/09/07 Python
Python写出新冠状病毒确诊人数地图的方法
2020/02/12 Python
pycharm-professional-2020.1下载与激活的教程
2020/09/21 Python
appium+python自动化配置(adk、jdk、node.js)
2020/11/17 Python
英国舒适型鞋履品牌:FitFlop
2017/05/17 全球购物
财务信息服务专业自荐书范文
2014/02/08 职场文书
趣味比赛活动方案
2014/02/15 职场文书
本科毕业生求职信
2014/06/15 职场文书
顶岗实习协议书
2015/01/29 职场文书
证券区域经理岗位职责
2015/04/10 职场文书
财务统计员岗位职责
2015/04/14 职场文书
浅谈css实现背景颜色半透明的两种方法
2021/12/06 HTML / CSS
vue实现Toast组件轻提示
2022/04/10 Vue.js
使用 Docker Compose 构建复杂的多容器App
2022/04/30 Servers