使用TensorFlow搭建一个全连接神经网络教程


Posted in Python onFebruary 06, 2020

说明

本例子利用TensorFlow搭建一个全连接神经网络,实现对MNIST手写数字的识别。

先上代码

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# prepare data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

# the model of the fully-connected network
weights = tf.Variable(tf.random_normal([784, 10]))
biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# compute the accuracy
correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

代码解析

1. 读取MNIST数据

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

2. 建立占位符

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

xs 代表图片像素数据, 每张图片(28×28)被展开成(1×784), 有多少图片还未定, 所以shape为None×784.

ys 代表图片标签数据, 0-9十个数字被表示成One-hot形式, 即只有对应bit为1, 其余为0.

3. 建立模型

weights = tf.Variable(tf.random_normal([784, 10]))


biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

使用Softmax函数作为激活函数:

使用TensorFlow搭建一个全连接神经网络教程

4. 计算正确率

correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

5. 使用模型

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

运行结果

训练1000个循环, 准确率在87%左右.

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.1041
0.632
0.7357
0.7837
0.7971
0.8147
0.8283
0.8376
0.8423
0.8501
0.8501
0.8533
0.8567
0.8597
0.8552
0.8647
0.8654
0.8701
0.8712
0.8712

以上这篇使用TensorFlow搭建一个全连接神经网络教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python base64 decode incorrect padding错误解决方法
Jan 08 Python
Python使用Flask框架获取当前查询参数的方法
Mar 21 Python
深入分析python中整型不会溢出问题
Jun 18 Python
python检测主机的连通性并记录到文件的实例
Jun 21 Python
Django中create和save方法的不同
Aug 13 Python
Python中字典与恒等运算符的用法分析
Aug 22 Python
jenkins配置python脚本定时任务过程图解
Oct 29 Python
Python实现报警信息实时发送至邮箱功能(实例代码)
Nov 11 Python
给Python学习者的文件读写指南(含基础与进阶)
Jan 29 Python
Django通过json格式收集主机信息
May 29 Python
python xlwt模块的使用解析
Apr 13 Python
python pandas 解析(读取、写入)CSV 文件的操作方法
Dec 24 Python
详解python 降级到3.6终极解决方案
Feb 06 #Python
PyCharm如何导入python项目的方法
Feb 06 #Python
tensorflow 环境变量设置方式
Feb 06 #Python
快速查找Python安装路径方法
Feb 06 #Python
运行tensorflow python程序,限制对GPU和CPU的占用操作
Feb 06 #Python
如何在django中添加日志功能
Feb 06 #Python
keras tensorflow 实现在python下多进程运行
Feb 06 #Python
You might like
Linux下PHP安装mcrypt扩展模块笔记
2014/09/10 PHP
Yii核心组件AssetManager原理分析
2014/12/02 PHP
php+Mysqli利用事务处理转账问题实例
2015/02/11 PHP
详解PHP的Yii框架中扩展的安装与使用
2016/04/01 PHP
PHP实现的登录页面信息提示功能示例
2017/07/24 PHP
thinkphp框架表单数组实现图片批量上传功能示例
2020/04/04 PHP
js实现权限树的更新权限时的全选全消功能
2009/02/17 Javascript
提高javascript效率 一次判断,而不要次次判断
2012/03/30 Javascript
javascript解析json数据的3种方式
2014/05/08 Javascript
jQuery使用$.get()方法从服务器文件载入数据实例
2015/03/25 Javascript
javascript实现百度地图鼠标滑动事件显示、隐藏
2015/04/02 Javascript
JS模仿腾讯图片站的图片翻页按钮效果完整实例
2016/06/21 Javascript
easyui datebox 时间限制,datebox开始时间限制结束时间,datebox截止日期比起始日期大的实现代码
2017/01/12 Javascript
vue 计时器组件的实现代码
2017/09/14 Javascript
前端MVVM框架解析之双向绑定
2018/01/24 Javascript
对Vue2 自定义全局指令Vue.directive和指令的生命周期介绍
2018/08/30 Javascript
Vue infinite update loop的问题解决
2019/04/23 Javascript
Async/Await替代Promise的6个理由
2019/06/15 Javascript
vue中使用element ui的弹窗与echarts之间的问题详解
2019/10/25 Javascript
OpenLayers3加载常用控件使用方法详解
2020/09/25 Javascript
tensorflow: 查看 tensor详细数值方法
2018/06/13 Python
Python+OpenCV实现图像融合的原理及代码
2018/12/03 Python
Python 获取div标签中的文字实例
2018/12/20 Python
Python代码实现http/https代理服务器的脚本
2019/08/12 Python
详解字符串在Python内部是如何省内存的
2020/02/03 Python
世界上最大的各式箱包网络零售店:eBag
2016/07/21 全球购物
100%羊绒:NakedCashmere
2020/08/26 全球购物
介绍一下HDLC(High-Level Data Link Control)高层数据链路协议
2012/01/21 面试题
银行类自荐信
2014/02/04 职场文书
优秀实习生感言
2014/03/01 职场文书
年终奖发放方案
2014/06/02 职场文书
人力资源本科毕业生求职信
2014/06/04 职场文书
2014年乡镇妇联工作总结
2014/12/02 职场文书
新农村建设指导员工作总结
2015/08/13 职场文书
python读取pdf格式文档的实现代码
2021/04/01 Python
Python如何把不同类型数据的json序列化
2021/04/30 Python